From c7a218c7aa6104eac202470a77cdb00e05e2badf Mon Sep 17 00:00:00 2001 From: Simon Dirmeier Date: Sat, 24 Feb 2024 14:01:07 +0000 Subject: [PATCH] Add SRNE and remove rejection ABC (#21) * Add SRNE * Remove rejection ABC --- README.md | 32 +- examples/bivariate_gaussian_snl.py | 2 +- examples/bivariate_gaussian_snp.py | 2 +- examples/bivariate_gaussian_snr.py | 72 ++++ ...riate_gaussian_snasss.py => slcp_snass.py} | 81 +++- examples/slcp_ssnl.py | 15 +- pyproject.toml | 10 +- sbijax/__init__.py | 4 +- sbijax/_src/_sne_base.py | 6 +- sbijax/_src/abc/rejection_abc.py | 88 ----- sbijax/_src/mcmc/__init__.py | 3 + sbijax/_src/nn/make_flows.py | 167 ++++++++ sbijax/_src/nn/make_resnet.py | 114 ++++++ sbijax/_src/snass.py | 2 +- sbijax/_src/snl.py | 19 +- sbijax/_src/snp.py | 4 +- sbijax/_src/snr.py | 368 ++++++++++++++++++ sbijax/mcmc/__init__.py | 7 + sbijax/nn/__init__.py | 5 + 19 files changed, 858 insertions(+), 143 deletions(-) create mode 100644 examples/bivariate_gaussian_snr.py rename examples/{bivariate_gaussian_snasss.py => slcp_snass.py} (51%) delete mode 100644 sbijax/_src/abc/rejection_abc.py create mode 100644 sbijax/_src/nn/make_flows.py create mode 100644 sbijax/_src/nn/make_resnet.py create mode 100644 sbijax/_src/snr.py create mode 100644 sbijax/mcmc/__init__.py diff --git a/README.md b/README.md index e29b69f..4731c36 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # sbijax -[![status](http://www.repostatus.org/badges/latest/concept.svg)](http://www.repostatus.org/#concept) +[![active](https://www.repostatus.org/badges/latest/active.svg)](https://www.repostatus.org/#active) [![ci](https://github.com/dirmeier/sbijax/actions/workflows/ci.yaml/badge.svg)](https://github.com/dirmeier/sbijax/actions/workflows/ci.yaml) [![version](https://img.shields.io/pypi/v/sbijax.svg?colorB=black&style=flat)](https://pypi.org/project/sbijax/) @@ -8,22 +8,27 @@ ## About -SbiJAX implements several algorithms for simulation-based inference using 
-[JAX](https://github.com/google/jax), [Haiku](https://github.com/deepmind/dm-haiku) and [BlackJAX](https://github.com/blackjax-devs/blackjax). +`sbijax` implements several algorithms for simulation-based inference in +[JAX](https://github.com/google/jax) using [Haiku](https://github.com/deepmind/dm-haiku), +[Distrax](https://github.com/deepmind/distrax) and [BlackJAX](https://github.com/blackjax-devs/blackjax). Specifically, `sbijax` implements -SbiJAX so far implements +- [Sequential Monte Carlo ABC](https://www.routledge.com/Handbook-of-Approximate-Bayesian-Computation/Sisson-Fan-Beaumont/p/book/9780367733728) (`SMCABC`), +- [Neural Likelihood Estimation](https://arxiv.org/abs/1805.07226) (`SNL`) +- [Surjective Neural Likelihood Estimation](https://arxiv.org/abs/2308.01054) (`SSNL`) +- [Neural Posterior Estimation C](https://arxiv.org/abs/1905.07488) (short `SNP`) +- [Contrastive Neural Ratio Estimation](https://arxiv.org/abs/2210.06170) (short `SNR`) +- [Neural Approximate Sufficient Statistics](https://arxiv.org/abs/2010.10079) (`SNASS`) +- [Neural Approximate Slice Sufficient Statistics](https://openreview.net/forum?id=jjzJ768iV1) (`SNASSS`) -- Rejection ABC (`RejectionABC`), -- Sequential Monte Carlo ABC (`SMCABC`), -- Sequential Neural Likelihood Estimation (`SNL`) -- Surjective Sequential Neural Likelihood Estimation (`SSNL`) -- Sequential Neural Posterior Estimation C (short `SNP`) +where the acronyms in parentheses denote the names of the methods in `sbijax`. ## Examples -You can find several self-contained examples on how to use the algorithms in `examples`. +You can find several self-contained examples on how to use the algorithms in [examples](https://github.com/dirmeier/sbijax/tree/main/examples). -## Usage +## Documentation + +Documentation can be found [here](https://sbijax.readthedocs.io/en/latest/). 
## Installation @@ -42,6 +47,11 @@ To install the latest GitHub , use: pip install git+https://github.com/dirmeier/sbijax@ ``` +## Acknowledgements + +> 📝 The package draws significant inspiration from the excellent Pytorch-based [`sbi`](https://github.com/sbi-dev/sbi) package which is substantially more +feature-complete and user-friendly, and better documented. + ## Author Simon Dirmeier sfyrbnd @ pm me diff --git a/examples/bivariate_gaussian_snl.py b/examples/bivariate_gaussian_snl.py index a38f145..b41407a 100644 --- a/examples/bivariate_gaussian_snl.py +++ b/examples/bivariate_gaussian_snl.py @@ -13,12 +13,12 @@ from jax import numpy as jnp from jax import random as jr from surjectors import ( - MADE, Chain, MaskedAutoregressive, Permutation, TransformedDistribution, ) +from surjectors.nn import MADE from surjectors.util import unstack from sbijax import SNL diff --git a/examples/bivariate_gaussian_snp.py b/examples/bivariate_gaussian_snp.py index bc058b8..2ad31b7 100644 --- a/examples/bivariate_gaussian_snp.py +++ b/examples/bivariate_gaussian_snp.py @@ -1,5 +1,5 @@ """ -Example using sequential posterior estimation on a bivariate Gaussian +Example using sequential neural posterior estimation on a bivariate Gaussian """ import distrax diff --git a/examples/bivariate_gaussian_snr.py b/examples/bivariate_gaussian_snr.py new file mode 100644 index 0000000..2da0eba --- /dev/null +++ b/examples/bivariate_gaussian_snr.py @@ -0,0 +1,72 @@ +""" +Example using sequential neural ratio estimation on a bivariate Gaussian +""" + +import distrax +import haiku as hk +import matplotlib.pyplot as plt +import optax +import seaborn as sns +from jax import numpy as jnp +from jax import random as jr + +from sbijax import SNR + + +def prior_model_fns(): + p = distrax.Independent(distrax.Normal(jnp.zeros(2), jnp.ones(2)), 1) + return p.sample, p.log_prob + + +def simulator_fn(seed, theta): + p = distrax.Normal(jnp.zeros_like(theta), 1.0) + y = theta + p.sample(seed=seed) + return 
y + + +def make_model(): + @hk.without_apply_rng + @hk.transform + def _mlp(inputs, **kwargs): + return hk.nets.MLP([64, 64, 1])(inputs) + + return _mlp + + +def run(): + y_observed = jnp.array([2.0, -2.0]) + + prior_simulator_fn, prior_logdensity_fn = prior_model_fns() + fns = (prior_simulator_fn, prior_logdensity_fn), simulator_fn + + snr = SNR(fns, make_model()) + optimizer = optax.adam(1e-3) + + data, params = None, {} + for i in range(5): + data, _ = snr.simulate_data_and_possibly_append( + jr.fold_in(jr.PRNGKey(1), i), + params=params, + observable=y_observed, + data=data, + ) + params, info = snr.fit( + jr.fold_in(jr.PRNGKey(2), i), + data=data, + optimizer=optimizer, + batch_size=100, + ) + + rng_key = jr.PRNGKey(23) + snr_samples, _ = snr.sample_posterior(rng_key, params, y_observed) + fig, axes = plt.subplots(2) + for i, ax in enumerate(axes): + sns.histplot(snr_samples[:, i], color="darkblue", ax=ax) + ax.set_xlim([-3.0, 3.0]) + sns.despine() + plt.tight_layout() + plt.show() + + +if __name__ == "__main__": + run() diff --git a/examples/bivariate_gaussian_snasss.py b/examples/slcp_snass.py similarity index 51% rename from examples/bivariate_gaussian_snasss.py rename to examples/slcp_snass.py index 3e335d4..b8400b4 100644 --- a/examples/bivariate_gaussian_snasss.py +++ b/examples/slcp_snass.py @@ -1,6 +1,5 @@ """ -Example using sequential neural approximate (slice) summary statistics on a -bivariate Gaussian with repeated dimensions +Example SNASS on the SLCP experiment """ import distrax @@ -11,6 +10,7 @@ import seaborn as sns from jax import numpy as jnp from jax import random as jr +from jax import scipy as jsp from surjectors import ( Chain, MaskedAutoregressive, @@ -23,20 +23,64 @@ from sbijax import SNASSS from sbijax.nn import make_snasss_net -W = jr.normal(jr.PRNGKey(0), (2, 10)) - def prior_model_fns(): - p = distrax.Independent(distrax.Normal(jnp.zeros(2), jnp.ones(2)), 1) + p = distrax.Independent( + distrax.Uniform(jnp.full(5, -3.0), 
jnp.full(5, 3.0)), 1 + ) return p.sample, p.log_prob def simulator_fn(seed, theta): - y = theta @ W - y = y + distrax.Normal(jnp.zeros_like(y), 0.1).sample(seed=seed) + orig_shape = theta.shape + if theta.ndim == 2: + theta = theta[:, None, :] + us_key, noise_key = jr.split(seed) + + def _unpack_params(ps): + m0 = ps[..., [0]] + m1 = ps[..., [1]] + s0 = ps[..., [2]] ** 2 + s1 = ps[..., [3]] ** 2 + r = jnp.tanh(ps[..., [4]]) + return m0, m1, s0, s1, r + + m0, m1, s0, s1, r = _unpack_params(theta) + us = distrax.Normal(0.0, 1.0).sample( + seed=us_key, sample_shape=(theta.shape[0], theta.shape[1], 4, 2) + ) + xs = jnp.empty_like(us) + xs = xs.at[:, :, :, 0].set(s0 * us[:, :, :, 0] + m0) + y = xs.at[:, :, :, 1].set( + s1 * (r * us[:, :, :, 0] + jnp.sqrt(1.0 - r**2) * us[:, :, :, 1]) + m1 + ) + if len(orig_shape) == 2: + y = y.reshape((*theta.shape[:1], 8)) + else: + y = y.reshape((*theta.shape[:2], 8)) return y +def likelihood_fn(theta, y): + mu = jnp.tile(theta[:2], 4) + s1, s2 = theta[2] ** 2, theta[3] ** 2 + corr = s1 * s2 * jnp.tanh(theta[4]) + cov = jnp.array([[s1**2, corr], [corr, s2**2]]) + cov = jsp.linalg.block_diag(*[cov for _ in range(4)]) + p = distrax.MultivariateNormalFullCovariance(mu, cov) + return p.log_prob(y) + + +def log_density_fn(theta, y): + prior_lp = distrax.Independent( + distrax.Uniform(jnp.full(5, -3.0), jnp.full(5, 3.0)), 1 + ).log_prob(theta) + likelihood_lp = likelihood_fn(theta, y) + + lp = jnp.sum(prior_lp) + jnp.sum(likelihood_lp) + return lp + + def make_model(dim): def _bijector_fn(params): means, log_scales = unstack(params, -1) @@ -50,7 +94,7 @@ def _flow(method, **kwargs): bijector_fn=_bijector_fn, conditioner=MADE( dim, - [50, 50, dim * 2], + [64, 64, dim * 2], 2, w_init=hk.initializers.TruncatedNormal(0.001), b_init=jnp.zeros, @@ -75,38 +119,39 @@ def _flow(method, **kwargs): def run(): - y_observed = jnp.array([[2.0, -2.0]]) @ W + thetas = jnp.linspace(-2.0, 2.0, 5) + y_0 = simulator_fn(jr.PRNGKey(0), thetas.reshape(-1, 
5)).reshape(-1, 8) prior_simulator_fn, prior_logdensity_fn = prior_model_fns() fns = (prior_simulator_fn, prior_logdensity_fn), simulator_fn estim = SNASSS( fns, - make_model(2), - make_snasss_net([64, 64, 2], [64, 64, 1], [64, 64, 1]), + make_model(5), + make_snasss_net([64, 64, 5], [64, 64, 1], [64, 64, 1]), ) optimizer = optax.adam(1e-3) data, params = None, {} - for i in range(2): + for i in range(5): data, _ = estim.simulate_data_and_possibly_append( - jr.fold_in(jr.PRNGKey(1), i), + jr.fold_in(jr.PRNGKey(12), i), params=params, - observable=y_observed, + observable=y_0, data=data, ) params, _ = estim.fit( - jr.fold_in(jr.PRNGKey(2), i), + jr.fold_in(jr.PRNGKey(23), i), data=data, optimizer=optimizer, batch_size=100, ) rng_key = jr.PRNGKey(23) - snp_samples, _ = estim.sample_posterior(rng_key, params, y_observed) - fig, axes = plt.subplots(2) + snasss_samples, _ = estim.sample_posterior(rng_key, params, y_0) + fig, axes = plt.subplots(5) for i, ax in enumerate(axes): - sns.histplot(snp_samples[:, i], color="darkblue", ax=ax) + sns.histplot(snasss_samples[:, i], color="darkblue", ax=ax) ax.set_xlim([-3.0, 3.0]) sns.despine() plt.tight_layout() diff --git a/examples/slcp_ssnl.py b/examples/slcp_ssnl.py index 0be7b35..662ebdd 100644 --- a/examples/slcp_ssnl.py +++ b/examples/slcp_ssnl.py @@ -4,13 +4,11 @@ import argparse from functools import partial -from timeit import default_timer as timer import distrax import haiku as hk import jax import matplotlib.pyplot as plt -import numpy as np import optax import pandas as pd import seaborn as sns @@ -25,11 +23,11 @@ Permutation, TransformedDistribution, ) -from surjectors.conditioners import MADE, mlp_conditioner +from surjectors.nn import MADE, make_mlp from surjectors.util import unstack from sbijax import SNL -from sbijax.mcmc.slice import sample_with_slice +from sbijax.mcmc import sample_with_slice def prior_model_fns(): @@ -50,7 +48,7 @@ def _unpack_params(ps): m1 = ps[..., [1]] s0 = ps[..., [2]] ** 2 s1 = 
ps[..., [3]] ** 2 - r = np.tanh(ps[..., [4]]) + r = jnp.tanh(ps[..., [4]]) return m0, m1, s0, s1, r m0, m1, s0, s1, r = _unpack_params(theta) @@ -60,7 +58,7 @@ def _unpack_params(ps): xs = jnp.empty_like(us) xs = xs.at[:, :, :, 0].set(s0 * us[:, :, :, 0] + m0) y = xs.at[:, :, :, 1].set( - s1 * (r * us[:, :, :, 0] + np.sqrt(1.0 - r**2) * us[:, :, :, 1]) + m1 + s1 * (r * us[:, :, :, 0] + jnp.sqrt(1.0 - r**2) * us[:, :, :, 1]) + m1 ) if len(orig_shape) == 2: y = y.reshape((*theta.shape[:1], 8)) @@ -95,7 +93,7 @@ def _bijector_fn(params): return distrax.ScalarAffine(means, jnp.exp(log_scales)) def _decoder_fn(n_dim): - decoder_net = mlp_conditioner( + decoder_net = make_mlp( [50, n_dim * 2], w_init=hk.initializers.TruncatedNormal(stddev=0.001), ) @@ -189,7 +187,6 @@ def run(use_surjectors): optimizer = optax.adam(1e-3) data, params = None, {} - start = timer() for i in range(5): data, _ = snl.simulate_data_and_possibly_append( jr.fold_in(jr.PRNGKey(12), i), @@ -201,8 +198,6 @@ def run(use_surjectors): params, info = snl.fit( jr.fold_in(jr.PRNGKey(23), i), data=data, optimizer=optimizer ) - end = timer() - print(end - start) sample_key, rng_key = jr.split(jr.PRNGKey(123)) snl_samples, _ = snl.sample_posterior(sample_key, params, y_observed) diff --git a/pyproject.toml b/pyproject.toml index 1c476d2..2cc200f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,8 @@ dependencies = [ "dm-haiku>=0.0.9", "optax>=0.1.3", "surjectors>=0.3.0", - "tfp-nightly>=0.20.0.dev20230404" + "tfp-nightly>=0.20.0.dev20230404", + "tqdm>=4.64.1" ] dynamic = ["version"] @@ -46,6 +47,12 @@ exclude = [ "/.pre-commit-config.yaml" ] +[tool.hatch.envs.examples] +dependencies = [ + "matplotlib>=3.6.1", + "seaborn>=0.12.2" +] + [tool.hatch.envs.test] dependencies = [ "pylint>=2.15.10", @@ -57,7 +64,6 @@ dependencies = [ lint = 'pylint sbijax' test = 'pytest -v --doctest-modules --cov=./sbijax --cov-report=xml sbijax' - [tool.black] line-length = 80 target-version = ['py311'] diff --git 
a/sbijax/__init__.py b/sbijax/__init__.py index 6357a09..7cfc29f 100644 --- a/sbijax/__init__.py +++ b/sbijax/__init__.py @@ -2,12 +2,12 @@ sbijax: Simulation-based inference in JAX """ -__version__ = "0.1.5" +__version__ = "0.1.6" -from sbijax._src.abc.rejection_abc import RejectionABC from sbijax._src.abc.smc_abc import SMCABC from sbijax._src.snass import SNASS from sbijax._src.snasss import SNASSS from sbijax._src.snl import SNL from sbijax._src.snp import SNP +from sbijax._src.snr import SNR diff --git a/sbijax/_src/_sne_base.py b/sbijax/_src/_sne_base.py index 0fc377b..c69e17a 100644 --- a/sbijax/_src/_sne_base.py +++ b/sbijax/_src/_sne_base.py @@ -13,15 +13,15 @@ class SNE(SBI, ABC): """Sequential neural estimation base class.""" - def __init__(self, model_fns, density_estimator): + def __init__(self, model_fns, network): """Construct an SNE object. Args: model_fns: tuple - density_estimator: maf + network: maf """ super().__init__(model_fns) - self.model = density_estimator + self.model = network self.n_total_simulations = 0 def simulate_data_and_possibly_append( diff --git a/sbijax/_src/abc/rejection_abc.py b/sbijax/_src/abc/rejection_abc.py deleted file mode 100644 index d8f9506..0000000 --- a/sbijax/_src/abc/rejection_abc.py +++ /dev/null @@ -1,88 +0,0 @@ -from typing import Callable, Tuple - -from jax import numpy as jnp -from jax import random as jr - -from sbijax._src._sbi_base import SBI - - -# pylint: disable=too-many-instance-attributes,too-many-arguments -# pylint: disable=too-many-locals,too-few-public-methods, -class RejectionABC(SBI): - """Rejection approximate Bayesian computation. - - Implements algorithm~4.1 from [1]. - - References: - .. [1] Sisson, Scott A, et al. "Handbook of approximate Bayesian - computation". 2019 - """ - - def __init__( - self, model_fns: Tuple, summary_fn: Callable, kernel_fn: Callable - ): - """Constructs a RejectionABC object. 
- - Args: - model_fns: tuple - summary_fn: summary statistice function - kernel_fn: a kernel function to compute similarities - """ - super().__init__(model_fns) - self.kernel_fn = kernel_fn - self.summary_fn = summary_fn - - # pylint: disable=arguments-differ - def sample_posterior( - self, - rng_key, - observable, - n_samples, - n_simulations_per_theta, - K, - h, - **kwargs, - ): - r"""Sample from the approximate posterior. - - Args: - rng_key: a random key - observable: observation to condition on - n_samples: number of samples to draw for each parameter - n_simulations_per_theta: number of simulations for each parameter - sample - K: normalisation parameter - h: kernel scale - - Returns: - an array of samples from the posterior distribution of dimension - (n_samples \times p) - """ - observable = jnp.atleast_2d(observable) - - thetas = None - n = n_samples - K = jnp.maximum( - K, self.kernel_fn(jnp.zeros((1, 2, 2)), jnp.zeros((1, 2, 2)))[0] - ) - while n > 0: - p_key, simulate_key, prior_key, rng_key = jr.split(rng_key) - n_sim = jnp.minimum(n, 1000) - ps = self.prior_sampler_fn(seed=prior_key, sample_shape=(n_sim,)) - ys = self.simulator_fn( - seed=simulate_key, - theta=jnp.tile(ps, [n_simulations_per_theta, 1, 1]), - ) - ys = jnp.swapaxes(ys, 1, 0) - k = self.kernel_fn( - self.summary_fn(ys), self.summary_fn(observable), h - ) - p = jr.uniform(p_key, shape=(len(k),)) - mr = k / K - idxs = jnp.where(p < mr)[0] - if thetas is None: - thetas = ps[idxs] - else: - thetas = jnp.vstack([thetas, ps[idxs]]) - n -= len(idxs) - return thetas[:n_samples, :] diff --git a/sbijax/_src/mcmc/__init__.py b/sbijax/_src/mcmc/__init__.py index d26d3dc..3ddef25 100644 --- a/sbijax/_src/mcmc/__init__.py +++ b/sbijax/_src/mcmc/__init__.py @@ -1,3 +1,6 @@ +from sbijax._src.mcmc.irmh import sample_with_imh +from sbijax._src.mcmc.mala import sample_with_mala from sbijax._src.mcmc.nuts import sample_with_nuts +from sbijax._src.mcmc.rmh import sample_with_rmh from sbijax._src.mcmc.sample 
import mcmc_diagnostics from sbijax._src.mcmc.slice import sample_with_slice diff --git a/sbijax/_src/nn/make_flows.py b/sbijax/_src/nn/make_flows.py new file mode 100644 index 0000000..2ab3ca4 --- /dev/null +++ b/sbijax/_src/nn/make_flows.py @@ -0,0 +1,167 @@ +from typing import Callable, Iterable, List + +import distrax +import haiku as hk +import jax +from jax import numpy as jnp +from surjectors import ( + AffineMaskedAutoregressiveInferenceFunnel, + Chain, + MaskedAutoregressive, + Permutation, + TransformedDistribution, +) +from surjectors._src.conditioners.mlp import make_mlp +from surjectors._src.conditioners.nn.made import MADE +from surjectors.util import unstack + + +def _bijector_fn(params): + means, log_scales = unstack(params, -1) + return distrax.ScalarAffine(means, jnp.exp(log_scales)) + + +def _decoder_fn(n_dim, hidden_size): + decoder_net = make_mlp( + hidden_size + [n_dim * 2], + w_init=hk.initializers.TruncatedNormal(stddev=0.001), + ) + + def _fn(z): + params = decoder_net(z) + mu, log_scale = jnp.split(params, 2, -1) + return distrax.Independent(distrax.Normal(mu, jnp.exp(log_scale)), 1) + + return _fn + + +# pylint: disable=too-many-arguments +def make_affine_maf( + n_dimension: int, + n_layers: int = 5, + hidden_sizes: Iterable[int] = (64, 64), + activation: Callable = jax.nn.tanh, +): + """Create an affine masked autoregressive flow. 
+ + Args: + n_dimension: dimensionality of data + n_layers: number of normalizing flow layers + hidden_sizes: sizes of hidden layers for each normalizing flow + activation: a jax activation function + + Returns: + a normalizing flow model + """ + + @hk.without_apply_rng + @hk.transform + def _flow(method, **kwargs): + layers = [] + order = jnp.arange(n_dimension) + for _ in range(n_layers): + layer = MaskedAutoregressive( + bijector_fn=_bijector_fn, + conditioner=MADE( + n_dimension, + list(hidden_sizes) + [n_dimension * 2], + 2, + w_init=hk.initializers.TruncatedNormal(0.001), + b_init=jnp.zeros, + activation=activation, + ), + ) + order = order[::-1] + layers.append(layer) + layers.append(Permutation(order, 1)) + chain = Chain(layers[:-1]) + + base_distribution = distrax.Independent( + distrax.Normal(jnp.zeros(n_dimension), jnp.ones(n_dimension)), + 1, + ) + td = TransformedDistribution(base_distribution, chain) + return td(method, **kwargs) + + return _flow + + +def make_surjective_affine_maf( + n_dimension: int, + n_layer_dimensions: List[int], + n_layers: int = 5, + hidden_sizes: Iterable[int] = (64, 64), + activation: Callable = jax.nn.tanh, +): + """Create a surjective affine masked autoregressive flow. 
+ + Args: + n_dimension: a list of integers that determine the dimensionality + of each flow layer + n_layer_dimensions: list of integers that determine if a layer is + dimensionality-preserving or -reducing + n_layers: number of normalizing flow layers + hidden_sizes: sizes of hidden layers for each normalizing flow + activation: a jax activation function + + Returns: + a normalizing flow model + """ + + @hk.without_apply_rng + @hk.transform + def _flow(method, **kwargs): + layers = [] + order = jnp.arange(n_dimension) + curr_dim = n_dimension + for i, n_dim_curr_layer in zip( + range(n_layers[:-1]), n_layer_dimensions[:-1] + ): + # layer is dimensionality preserving + if n_dim_curr_layer == curr_dim: + layer = MaskedAutoregressive( + bijector_fn=_bijector_fn, + conditioner=MADE( + n_dim_curr_layer, + list(hidden_sizes) + [n_dim_curr_layer * 2], + 2, + w_init=hk.initializers.TruncatedNormal(0.001), + b_init=jnp.zeros, + activation=activation, + ), + ) + order = order[::-1] + elif n_dim_curr_layer < curr_dim: + n_latent = n_dim_curr_layer + layer = AffineMaskedAutoregressiveInferenceFunnel( + n_latent, + _decoder_fn(curr_dim - n_latent, hidden_sizes), + conditioner=MADE( + n_latent, + hidden_sizes + [n_dim_curr_layer * 2], + 2, + w_init=hk.initializers.TruncatedNormal(0.001), + b_init=jnp.zeros, + activation=jax.nn.tanh, + ), + ) + curr_dim = n_latent + order = order[::-1] + order = order[:curr_dim] - jnp.min(order[:curr_dim]) + else: + raise ValueError( + f"n_dimension at layer {i} is layer than the dimension of" + f" the following layer {i + 1}" + ) + layers.append(layer) + layers.append(Permutation(order, 1)) + chain = Chain(layers[:-1]) + + base_distribution = distrax.Independent( + distrax.Normal(jnp.zeros(n_dimension), jnp.ones(n_dimension)), + 1, + ) + td = TransformedDistribution(base_distribution, chain) + return td(method, **kwargs) + + return _flow diff --git a/sbijax/_src/nn/make_resnet.py b/sbijax/_src/nn/make_resnet.py new file mode 100644 index 
0000000..bfe163f --- /dev/null +++ b/sbijax/_src/nn/make_resnet.py @@ -0,0 +1,114 @@ +from typing import Callable + +import haiku as hk +import jax + + +# pylint: disable=too-many-arguments +class _ResnetBlock(hk.Module): + """A block for a 1d residual network.""" + + def __init__( + self, + hidden_size: int, + activation: Callable = jax.nn.relu, + dropout_rate: float = 0.2, + do_batch_norm: bool = False, + batch_norm_decay: float = 0.1, + ): + super().__init__() + self.hidden_size = hidden_size + self.activation = activation + self.do_batch_norm = do_batch_norm + self.dropout_rate = dropout_rate + self.batch_norm_decay = batch_norm_decay + + def __call__(self, inputs, is_training=False): + outputs = inputs + if self.do_batch_norm: + outputs = hk.BatchNorm(True, True, self.batch_norm_decay)( + outputs, is_training=is_training + ) + outputs = hk.Linear(self.hidden_size)(outputs) + outputs = self.activation(outputs) + if is_training: + outputs = hk.dropout( + rng=hk.next_rng_key(), rate=self.dropout_rate, x=outputs + ) + outputs = hk.Linear(self.hidden_size)(outputs) + return outputs + inputs + + +# pylint: disable=too-many-arguments +class _Resnet(hk.Module): + """A simplified 1-d residual network.""" + + def __init__( + self, + n_layers: int, + hidden_size: int, + activation: Callable = jax.nn.relu, + dropout_rate: float = 0.2, + do_batch_norm: bool = True, + batch_norm_decay: float = 0.1, + ): + super().__init__() + self.n_layers = n_layers + self.hidden_size = hidden_size + self.activation = activation + self.do_batch_norm = do_batch_norm + self.dropout_rate = dropout_rate + self.batch_norm_decay = batch_norm_decay + + def __call__(self, inputs, is_training=False, **kwargs): + outputs = inputs + outputs = hk.Linear(self.hidden_size)(outputs) + outputs = self.activation(outputs) + for _ in range(self.n_layers): + outputs = _ResnetBlock( + hidden_size=self.hidden_size, + activation=self.activation, + dropout_rate=self.dropout_rate, + 
do_batch_norm=self.do_batch_norm, + batch_norm_decay=self.batch_norm_decay, + )(outputs, is_training=is_training) + outputs = self.activation(outputs) + outputs = hk.Linear(1)(outputs) + return outputs + + +def make_resnet( + n_layers: int = 2, + hidden_size: int = 64, + activation: Callable = jax.nn.tanh, + dropout_rate=0.2, + do_batch_norm=False, + batch_norm_decay=0.2, +): + """Create a resnet. + + Args: + n_layers: number of resnet blocks + hidden_size: size of the hidden layers in each resnet block + activation: a jax activation function + dropout_rate: dropout rate to use in resnet blocks + do_batch_norm: use batch normalization or not + batch_norm_decay: decay rate of EMA in batch norm layer + Returns: + a neural network model + """ + + @hk.without_apply_rng + @hk.transform + def _net(inputs, is_training=False): + nn = _Resnet( + n_layers=n_layers, + hidden_size=hidden_size, + activation=activation, + do_batch_norm=do_batch_norm, + dropout_rate=dropout_rate, + batch_norm_decay=batch_norm_decay, + ) + return nn(inputs, is_training=is_training) + + return _net diff --git a/sbijax/_src/snass.py b/sbijax/_src/snass.py index 9968c4f..3692de9 100644 --- a/sbijax/_src/snass.py +++ b/sbijax/_src/snass.py @@ -34,7 +34,7 @@ class SNASS(SNL): """Sequential neural approximate summary statistics. References: - .. [1] Yanzhi Chen et al. "Neural Approximate Sufficient Statistics for + .. [1] Chen, Yanzhi et al. "Neural Approximate Sufficient Statistics for Implicit Models". ICLR, 2021 """ diff --git a/sbijax/_src/snl.py b/sbijax/_src/snl.py index 92480bf..0f47caf 100644 --- a/sbijax/_src/snl.py +++ b/sbijax/_src/snl.py @@ -24,7 +24,16 @@ class SNL(SNE): """Sequential neural likelihood. - From the Papamakarios paper + Implements SNL and SSNL estimation methods. + + References: + .. [1] Papamakarios, George, et al. "Sequential neural likelihood: + Fast likelihood-free inference with autoregressive flows." 
+ International Conference on Artificial Intelligence and Statistics, + 2019. + .. [2] Dirmeier, Simon, et al. "Simulation-based inference using + surjective sequential neural likelihood estimation." + arXiv preprint arXiv:2308.01054, 2023. """ # pylint: disable=arguments-differ,too-many-locals @@ -128,7 +137,7 @@ def loss_fn(params): best_loss = validation_loss best_params = params.copy() - losses = jnp.vstack(losses)[:i, :] + losses = jnp.vstack(losses)[: (i + 1), :] return best_params, losses def _validation_loss(self, params, val_iter): @@ -237,9 +246,9 @@ def sample_posterior( rng_key, params, observable, - n_chains=4, - n_samples=2_000, - n_warmup=1_000, + n_chains=n_chains, + n_samples=n_samples, + n_warmup=n_warmup, **kwargs, ) diff --git a/sbijax/_src/snp.py b/sbijax/_src/snp.py index c822c12..67613ec 100644 --- a/sbijax/_src/snp.py +++ b/sbijax/_src/snp.py @@ -17,7 +17,9 @@ class SNP(SNE): """Sequential neural posterior estimation. References: - .. [1] + .. [1] Greenberg, David, et al. "Automatic posterior transformation for + likelihood-free inference." International Conference on Machine + Learning, 2019. 
""" def __init__(self, model_fns, density_estimator): diff --git a/sbijax/_src/snr.py b/sbijax/_src/snr.py new file mode 100644 index 0000000..4a24bf9 --- /dev/null +++ b/sbijax/_src/snr.py @@ -0,0 +1,368 @@ +# Parts of this codebase have been adopted from https://github.com/bkmi/cnre + +from functools import partial + +import chex +import jax +import numpy as np +import optax +from absl import logging +from jax import numpy as jnp +from jax import random as jr +from jax import scipy as jsp +from tqdm import tqdm + +from sbijax._src import mcmc +from sbijax._src._sne_base import SNE +from sbijax._src.mcmc import mcmc_diagnostics +from sbijax._src.util.early_stopping import EarlyStopping + + +def _get_prior_probs_marginal_and_joint(K, gamma): + p_marginal = 1 / (1 + gamma * K) + p_joint = gamma / (1 + gamma * K) + return p_marginal, p_joint + + +# pylint: disable=too-many-arguments +def _as_logits(params, rng_key, model, K, theta, y): + n = theta.shape[0] + y = jnp.repeat(y, K + 1, axis=0) + ps = jnp.ones((n, n)) * (1.0 - jnp.eye(n)) / (n - 1.0) + + choices = jax.vmap( + lambda key, p: jr.choice(key, n, (K,), replace=False, p=p) + )(jr.split(rng_key, n), ps) + + contrasting_theta = theta[choices] + atomic_theta = jnp.concatenate( + [theta[:, None, :], contrasting_theta], axis=1 + ).reshape(n * (K + 1), -1) + + inputs = jnp.concatenate([y, atomic_theta], axis=-1) + return model.apply(params, inputs, is_training=False) + + +def _marginal_joint_loss(gamma, num_classes, log_marg, log_joint): + loggamma = jnp.log(gamma) + logK = jnp.full((log_marg.shape[0], 1), jnp.log(num_classes)) + + denominator_marginal = jnp.concatenate( + [loggamma + log_marg, logK], + axis=-1, + ) + denomintator_joint = jnp.concatenate( + [loggamma + log_joint, logK], + axis=-1, + ) + + log_prob_marginal = logK - jsp.special.logsumexp( + denominator_marginal, axis=-1 + ) + log_prob_joint = ( + loggamma + + log_joint[:, 0] + - jsp.special.logsumexp(denomintator_joint, axis=-1) + ) + + p_marg, 
p_joint = _get_prior_probs_marginal_and_joint(num_classes, gamma) + loss = p_marg * log_prob_marginal + p_joint * num_classes * log_prob_joint + return loss + + +def _loss(params, rng_key, model, gamma, num_classes, **batch): + n, _ = batch["y"].shape + + rng_key1, rng_key2, rng_key = jr.split(rng_key, 3) + log_marg = _as_logits(params, rng_key1, model, num_classes, **batch) + log_joint = _as_logits(params, rng_key2, model, num_classes, **batch) + + log_marg = log_marg.reshape(n, num_classes + 1)[:, 1:] + log_joint = log_joint.reshape(n, num_classes + 1)[:, :-1] + + loss = _marginal_joint_loss(gamma, num_classes, log_marg, log_joint) + return -jnp.mean(loss) + + +# pylint: disable=too-many-arguments,unused-argument +class SNR(SNE): + """Sequential (contrastive) neural ratio estimation. + + References: + .. [1] Miller, Benjamin K., et al. "Contrastive neural ratio + estimation." Advances in Neural Information Processing Systems, 2022. + """ + + def __init__(self, model_fns, classifier, num_classes=10, gamma=1.0): + """Construct an SNR object. + + Args: + model_fns: tuple + classifier: a neural network for classification + num_classes: int + gamma: float + """ + super().__init__(model_fns, classifier) + self.gamma = gamma + self.num_classes = num_classes + + # pylint: disable=arguments-differ,too-many-locals + def fit( + self, + rng_key, + data, + *, + optimizer=optax.adam(0.003), + n_iter=1000, + batch_size=100, + percentage_data_as_validation_set=0.1, + n_early_stopping_patience=10, + **kwargs, + ): + """Fit an SNR model. 
+ + Args: + rng_key: a random key + data: data set obtained from calling + `simulate_data_and_possibly_append` + optimizer: an optax optimizer object + n_iter: maximal number of training iterations per round + batch_size: batch size used for training the model + percentage_data_as_validation_set: percentage of the simulated + data that is used for validation and early stopping + n_early_stopping_patience: number of iterations of no improvement + of training the flow before stopping optimisation + n_atoms: number of atoms to approximate the proposal posterior + + Returns: + a tuple of parameters and a tuple of the training information + """ + itr_key, rng_key = jr.split(rng_key) + train_iter, val_iter = self.as_iterators( + itr_key, data, batch_size, percentage_data_as_validation_set + ) + params, losses = self._fit_model_single_round( + seed=rng_key, + train_iter=train_iter, + val_iter=val_iter, + optimizer=optimizer, + n_iter=n_iter, + n_early_stopping_patience=n_early_stopping_patience, + ) + + return params, losses + + # pylint: disable=undefined-loop-variable + def _fit_model_single_round( + self, + seed, + train_iter, + val_iter, + optimizer, + n_iter, + n_early_stopping_patience, + ): + init_key, seed = jr.split(seed) + params = self._init_params(init_key, **train_iter(0)) + state = optimizer.init(params) + + loss_fn = partial(_loss, gamma=self.gamma, num_classes=self.num_classes) + + @jax.jit + def step(params, rng, state, **batch): + loss, grads = jax.value_and_grad(loss_fn)( + params, rng, self.model, **batch + ) + updates, new_state = optimizer.update(grads, state, params) + new_params = optax.apply_updates(params, updates) + return loss, new_params, new_state + + losses = np.zeros([n_iter, 2]) + early_stop = EarlyStopping(1e-3, n_early_stopping_patience) + best_params, best_loss = None, np.inf + logging.info("training model") + for i in tqdm(range(n_iter)): + train_loss = 0.0 + rng_key = jr.fold_in(seed, i) + for j in range(train_iter.num_batches): 
+ train_key, rng_key = jr.split(rng_key) + batch = train_iter(j) + batch_loss, params, state = step( + params, train_key, state, **batch + ) + train_loss += batch_loss * ( + batch["y"].shape[0] / train_iter.num_samples + ) + val_key, rng_key = jr.split(rng_key) + validation_loss = self._validation_loss(val_key, params, val_iter) + losses[i] = jnp.array([train_loss, validation_loss]) + + _, early_stop = early_stop.update(validation_loss) + if early_stop.should_stop: + logging.info("early stopping criterion found") + break + if validation_loss < best_loss: + best_loss = validation_loss + best_params = params.copy() + + losses = jnp.vstack(losses)[: (i + 1), :] + + return best_params, losses + + def _init_params(self, rng_key, **init_data): + params = self.model.init( + rng_key, + jnp.concatenate([init_data["y"], init_data["theta"]], axis=-1), + ) + return params + + def _validation_loss(self, rng_key, params, val_iter): + loss_fn = partial(_loss, gamma=self.gamma, num_classes=self.num_classes) + + @jax.jit + def body_fn(rng_key, **batch): + loss = loss_fn(params, rng_key, self.model, **batch) + return loss * (batch["y"].shape[0] / val_iter.num_samples) + + loss = 0.0 + for i in range(val_iter.num_batches): + val_key, rng_key = jr.split(rng_key) + loss += body_fn(val_key, **val_iter(i)) + return loss + + def simulate_data_and_possibly_append( + self, + rng_key, + params=None, + observable=None, + data=None, + n_simulations=1_000, + n_chains=4, + n_samples=2_000, + n_warmup=1_000, + **kwargs, + ): + """Simulate data from the prior or posterior. + + Args: + rng_key: a random key + params: a dictionary of neural network parameters + observable: an observation + data: existing data set + n_simulations: number of newly simulated data + n_chains: number of MCMC chains + n_samples: number of samples to draw in total + n_warmup: number of draws to discard + kwargs: keyword arguments + dictionary of key value pairs passed to `sample_posterior`. 
+ The following arguments are possible: + - sampler: either 'nuts', 'slice' or None (defaults to nuts) + - n_thin: number of thinning steps (int) + - n_doubling: number of doubling steps of the interval (int) + - step_size: step size of the initial interval (float) + + Returns: + returns a NamedTuple with two elements, y and theta + """ + return super().simulate_data_and_possibly_append( + rng_key=rng_key, + params=params, + observable=observable, + data=data, + n_simulations=n_simulations, + n_chains=n_chains, + n_samples=n_samples, + n_warmup=n_warmup, + **kwargs, + ) + + def sample_posterior( + self, + rng_key, + params, + observable, + *, + n_chains=4, + n_samples=2_000, + n_warmup=1_000, + **kwargs, + ): + r"""Sample from the approximate posterior. + + Args: + rng_key: a random key + params: a pytree of parameters for the model + observable: observation to condition on + n_chains: number of MCMC chains + n_samples: number of samples per chain + n_warmup: number of samples to discard + kwargs: keyword arguments with sampler specific parameters. 
For + slice sampling the following arguments are possible: + - sampler: either 'nuts', 'slice' or None (defaults to nuts) + - n_thin: number of thinning steps + - n_doubling: number of doubling steps of the interval + - step_size: step size of the initial interval + + Returns: + a tuple of the flattened posterior draws, an array of dimension + ((n_samples - n_warmup) * n_chains \times p), and the MCMC diagnostics + """ + observable = jnp.atleast_2d(observable) + return self._sample_posterior( + rng_key, + params, + observable, + n_chains=n_chains, + n_samples=n_samples, + n_warmup=n_warmup, + **kwargs, + ) + + def _sample_posterior( + self, + rng_key, + params, + observable, + *, + n_chains=4, + n_samples=2_000, + n_warmup=1_000, + **kwargs, + ): + part = partial(self.model.apply, params, is_training=False) + + def _joint_logdensity_fn(theta): + lp_prior = self.prior_log_density_fn(theta) + theta = theta.reshape(observable.shape) + lp = part(jnp.concatenate([observable, theta], axis=-1)) + return jnp.sum(lp_prior) + jnp.sum(lp) + + if "sampler" in kwargs and kwargs["sampler"] == "slice": + + def lp__(theta): + return jax.vmap(_joint_logdensity_fn)(theta) + + sampler = kwargs.pop("sampler", None) + else: + + def lp__(theta): + return _joint_logdensity_fn(**theta) + + # use the requested sampler; a missing or None value falls back to nuts + sampler = kwargs.pop("sampler", None) or "nuts" + + sampling_fn = getattr(mcmc, "sample_with_" + sampler) + samples = sampling_fn( + rng_key=rng_key, + lp=lp__, + prior=self.prior_sampler_fn, + n_chains=n_chains, + n_samples=n_samples, + n_warmup=n_warmup, + **kwargs, + ) + chex.assert_shape(samples, [n_samples - n_warmup, n_chains, None]) + diagnostics = mcmc_diagnostics(samples) + samples = samples.reshape((n_samples - n_warmup) * n_chains, -1) + + return samples, diagnostics diff --git a/sbijax/mcmc/__init__.py b/sbijax/mcmc/__init__.py new file mode 100644 index 0000000..877acc8 --- /dev/null +++ b/sbijax/mcmc/__init__.py @@ -0,0 +1,7 @@ +"""MCMC module.""" + +from sbijax._src.mcmc.irmh import 
sample_with_imh +from sbijax._src.mcmc.mala import sample_with_mala +from sbijax._src.mcmc.nuts import sample_with_nuts +from sbijax._src.mcmc.rmh import sample_with_rmh +from sbijax._src.mcmc.slice import sample_with_slice diff --git a/sbijax/nn/__init__.py b/sbijax/nn/__init__.py index 1ca8c87..c9082f7 100644 --- a/sbijax/nn/__init__.py +++ b/sbijax/nn/__init__.py @@ -1,3 +1,8 @@ """Neural network module.""" +from sbijax._src.nn.make_flows import ( + make_affine_maf, + make_surjective_affine_maf, +) +from sbijax._src.nn.make_resnet import make_resnet from sbijax._src.nn.make_snass_networks import make_snass_net, make_snasss_net