diff --git a/examples/bivariate_gaussian_snp.py b/examples/bivariate_gaussian_snp.py
new file mode 100644
index 0000000..575d5a2
--- /dev/null
+++ b/examples/bivariate_gaussian_snp.py
@@ -0,0 +1,101 @@
+"""
+Example using SNP and masked autoregressive flows
+"""
+
+import distrax
+import haiku as hk
+import jax.nn
+import matplotlib.pyplot as plt
+import optax
+import seaborn as sns
+from jax import numpy as jnp
+from jax import random
+from surjectors import Chain, TransformedDistribution
+from surjectors.bijectors.masked_autoregressive import MaskedAutoregressive
+from surjectors.bijectors.permutation import Permutation
+from surjectors.conditioners import MADE
+from surjectors.util import unstack
+
+from sbijax import SNP
+
+
+def prior_model_fns():
+    p = distrax.Independent(
+        distrax.Uniform(-2 * jnp.ones(2), 2 * jnp.ones(2)), 1
+    )
+    return p.sample, p.log_prob
+
+
+def simulator_fn(seed, theta):
+    p = distrax.Normal(jnp.zeros_like(theta), 0.1)
+    y = theta + p.sample(seed=seed)
+    return y
+
+
+def make_model(dim):
+    def _bijector_fn(params):
+        means, log_scales = unstack(params, -1)
+        return distrax.ScalarAffine(means, jnp.exp(log_scales))
+
+    def _flow(method, **kwargs):
+        layers = []
+        order = jnp.arange(dim)
+        for i in range(5):
+            layer = MaskedAutoregressive(
+                bijector_fn=_bijector_fn,
+                conditioner=MADE(
+                    dim,
+                    [50, 50, dim * 2],
+                    2,
+                    w_init=hk.initializers.TruncatedNormal(0.001),
+                    b_init=jnp.zeros,
+                    activation=jax.nn.tanh,
+                ),
+            )
+            order = order[::-1]
+            layers.append(layer)
+            layers.append(Permutation(order, 1))
+        chain = Chain(layers)
+
+        base_distribution = distrax.Independent(
+            distrax.Normal(jnp.zeros(dim), jnp.ones(dim)),
+            1,
+        )
+        td = TransformedDistribution(base_distribution, chain)
+        return td(method, **kwargs)
+
+    td = hk.transform(_flow)
+    return td
+
+
+def run():
+    y_observed = jnp.array([-2.0, 1.0])
+
+    prior_simulator_fn, prior_logdensity_fn = prior_model_fns()
+    fns = (prior_simulator_fn, prior_logdensity_fn), simulator_fn
+
+    optimizer = optax.chain(optax.clip(5.0), optax.adamw(1e-04))
+    snp = SNP(fns, make_model(2))
+    params, info = snp.fit(
+        random.PRNGKey(2),
+        y_observed,
+        n_rounds=5,
+        optimizer=optimizer,
+        n_early_stopping_patience=10,
+        batch_size=128,
+        n_atoms=10,
+        max_n_iter=200,
+    )
+
+    snp_samples, _ = snp.sample_posterior(params, 10000)
+    fig, axes = plt.subplots(2)
+    for i, ax in enumerate(axes):
+        sns.histplot(snp_samples[:, i], color="darkblue", ax=ax)
+        ax.set_xlim([-2.0, 2.0])
+    sns.despine()
+    plt.tight_layout()
+    plt.show()
+
+
+if __name__ == "__main__":
+    run()
diff --git a/examples/slcp_snp.py b/examples/slcp_snp.py
new file mode 100644
index 0000000..bde5db1
--- /dev/null
+++ b/examples/slcp_snp.py
@@ -0,0 +1,202 @@
+"""
+SLCP example from [1] using SNL and masked coupling bijections or surjections
+"""
+
+import argparse
+from functools import partial
+
+import distrax
+import haiku as hk
+import matplotlib.pyplot as plt
+import numpy as np
+import optax
+import pandas as pd
+import seaborn as sns
+from jax import numpy as jnp
+from jax import random
+from jax import scipy as jsp
+from surjectors import (
+    AffineMaskedCouplingInferenceFunnel,
+    Chain,
+    MaskedCoupling,
+    TransformedDistribution,
+)
+from surjectors.conditioners import mlp_conditioner
+from surjectors.util import make_alternating_binary_mask
+
+from sbijax import SNL
+from sbijax.mcmc import sample_with_nuts
+
+
+def prior_model_fns():
+    p = distrax.Independent(
+        distrax.Uniform(jnp.full(5, -3.0), jnp.full(5, 3.0)), 1
+    )
+    return p.sample, p.log_prob
+
+
+def 
likelihood_fn(theta, y): + mu = jnp.tile(theta[:2], 4) + s1, s2 = theta[2] ** 2, theta[3] ** 2 + corr = s1 * s2 * jnp.tanh(theta[4]) + cov = jnp.array([[s1**2, corr], [corr, s2**2]]) + cov = jsp.linalg.block_diag(*[cov for _ in range(4)]) + p = distrax.MultivariateNormalFullCovariance(mu, cov) + return p.log_prob(y) + + +def simulator_fn(seed, theta): + orig_shape = theta.shape + if theta.ndim == 2: + theta = theta[:, None, :] + us_key, noise_key = random.split(seed) + + def _unpack_params(ps): + m0 = ps[..., [0]] + m1 = ps[..., [1]] + s0 = ps[..., [2]] ** 2 + s1 = ps[..., [3]] ** 2 + r = np.tanh(ps[..., [4]]) + return m0, m1, s0, s1, r + + m0, m1, s0, s1, r = _unpack_params(theta) + us = distrax.Normal(0.0, 1.0).sample( + seed=us_key, sample_shape=(theta.shape[0], theta.shape[1], 4, 2) + ) + xs = jnp.empty_like(us) + xs = xs.at[:, :, :, 0].set(s0 * us[:, :, :, 0] + m0) + y = xs.at[:, :, :, 1].set( + s1 * (r * us[:, :, :, 0] + np.sqrt(1.0 - r**2) * us[:, :, :, 1]) + m1 + ) + if len(orig_shape) == 2: + y = y.reshape((*theta.shape[:1], 8)) + else: + y = y.reshape((*theta.shape[:2], 8)) + return y + + +def make_model(dim, use_surjectors): + def _bijector_fn(params): + means, log_scales = jnp.split(params, 2, -1) + return distrax.ScalarAffine(means, jnp.exp(log_scales)) + + def _conditional_fn(n_dim): + decoder_net = mlp_conditioner([32, 32, n_dim * 2]) + + def _fn(z): + params = decoder_net(z) + mu, log_scale = jnp.split(params, 2, -1) + return distrax.Independent( + distrax.Normal(mu, jnp.exp(log_scale)), 1 + ) + + return _fn + + def _flow(method, **kwargs): + layers = [] + n_dimension = dim + for i in range(5): + mask = make_alternating_binary_mask(n_dimension, i % 2 == 0) + if i == 2 and use_surjectors: + n_latent = 6 + layer = AffineMaskedCouplingInferenceFunnel( + n_latent, + _conditional_fn(n_dimension - n_latent), + mlp_conditioner([32, 32, n_dimension * 2]), + ) + n_dimension = n_latent + else: + layer = MaskedCoupling( + mask=mask, + bijector=_bijector_fn, + conditioner=mlp_conditioner([32, 32, n_dimension * 2]), + ) + layers.append(layer) + chain = Chain(layers) + + base_distribution = distrax.Independent( + distrax.Normal(jnp.zeros(n_dimension), jnp.ones(n_dimension)), + reinterpreted_batch_ndims=1, + ) + td = TransformedDistribution(base_distribution, chain) + return td(method, **kwargs) + + td = hk.transform(_flow) + td = hk.without_apply_rng(td) + return td + + +def run(use_surjectors): + len_theta = 5 + # this is the thetas used in SNL + # thetas = jnp.array([-0.7, -2.9, -1.0, -0.9, 0.6]) + y_observed = jnp.array( + [ + [ + -0.9707123, + -2.9461224, + -0.4494722, + -3.4231849, + -0.13285634, + -3.364017, + -0.85367596, + -2.4271638, + ] + ] + ) + prior_sampler, prior_fn = prior_model_fns() + fns = (prior_sampler, prior_fn), simulator_fn + model = make_model(y_observed.shape[1], use_surjectors) + snl = SNL(fns, model) + optimizer = optax.adam(1e-3) + params, info = snl.fit( + random.PRNGKey(23), y_observed, optimizer, n_rounds=10 + ) + + snl_samples, _ = snl.sample_posterior(params, 20, 50000, 10000) + snl_samples = snl_samples.reshape(-1, len_theta) + + def log_density_fn(theta, y): + prior_lp = prior_fn(theta) + likelihood_lp = likelihood_fn(theta, y) + + lp = jnp.sum(prior_lp) + jnp.sum(likelihood_lp) + return lp + + log_density_partial = partial(log_density_fn, y=y_observed) + log_density = lambda x: log_density_partial(**x) + + rng_seq = hk.PRNGSequence(12) + nuts_samples = sample_with_nuts( + rng_seq, log_density, len_theta, 20, 50000, 10000 + ) + nuts_samples = 
nuts_samples.reshape(-1, len_theta) + + g = sns.PairGrid(pd.DataFrame(nuts_samples)) + g.map_upper(sns.scatterplot, color="black", marker=".", edgecolor=None, s=2) + g.map_diag(plt.hist, color="black") + for ax in g.axes.flatten(): + ax.set_xlim(-5, 5) + ax.set_ylim(-5, 5) + g.fig.set_figheight(5) + g.fig.set_figwidth(5) + plt.show() + + fig, axes = plt.subplots(len_theta, 2) + for i in range(len_theta): + sns.histplot(nuts_samples[:, i], color="darkgrey", ax=axes[i, 0]) + sns.histplot(snl_samples[:, i], color="darkblue", ax=axes[i, 1]) + axes[i, 0].set_title(rf"Sampled posterior $\theta_{i}$") + axes[i, 1].set_title(rf"Approximated posterior $\theta_{i}$") + for j in range(2): + axes[i, j].set_xlim(-5, 5) + sns.despine() + plt.tight_layout() + plt.show() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--use-surjectors", action="store_true", default=True) + args = parser.parse_args() + run(args.use_surjectors) diff --git a/sbijax/__init__.py b/sbijax/__init__.py index 5a235d6..8e3bb0f 100644 --- a/sbijax/__init__.py +++ b/sbijax/__init__.py @@ -2,9 +2,10 @@ sbijax: Simulation-based inference in JAX """ -__version__ = "0.0.9" +__version__ = "0.0.10" from sbijax.abc.rejection_abc import RejectionABC from sbijax.abc.smc_abc import SMCABC from sbijax.snl import SNL +from sbijax.snp import SNP diff --git a/sbijax/_sbi_base.py b/sbijax/_sbi_base.py index 29a7fce..4bd20fa 100644 --- a/sbijax/_sbi_base.py +++ b/sbijax/_sbi_base.py @@ -1,11 +1,16 @@ +import abc +from typing import Optional + import chex import haiku as hk from jax import numpy as jnp from jax import random +from sbijax.generator import named_dataset + -# pylint: disable=too-many-instance-attributes -class SBI: +# pylint: disable=too-many-instance-attributes,unused-argument +class SBI(abc.ABC): """ SBI base class """ @@ -14,10 +19,43 @@ def __init__(self, model_fns): self.prior_sampler_fn, self.prior_log_density_fn = model_fns[0] self.simulator_fn = model_fns[1] self._len_theta = len(self.prior_sampler_fn(seed=random.PRNGKey(0))) - self.observed: chex.Array + + self._observed: chex.Array self._rng_seq: hk.PRNGSequence + self._data: Optional[named_dataset] = None + + @property + def observed(self): + """Get the observation to condition on""" + return self._observed + + @observed.setter + def observed(self, observed): + """Set the observation to condition on""" + self._observed = jnp.atleast_2d(observed) + + @property + def data(self): + """Get the data set""" + return self._data + + @data.setter + def data(self, data): + """Set the data set""" + if not isinstance(data, named_dataset): + raise TypeError("data is not of type 'named_dataset'") + self._data = data - def fit(self, rng_key, observed): + @property + def rng_seq(self): + """Rng sequence""" + return self._rng_seq + + @rng_seq.setter + def rng_seq(self, rng_seq): + self._rng_seq = rng_seq + + def fit(self, rng_key, observed, **kwargs): """ Fit the model @@ -30,8 +68,16 @@ def fit(self, rng_key, observed): number of samples """ - self._rng_seq = hk.PRNGSequence(rng_key) - self.observed = jnp.atleast_2d(observed) + self.rng_seq = hk.PRNGSequence(rng_key) + self.observed = observed + @abc.abstractmethod def sample_posterior(self, **kwargs): - """Sample from the posterior""" + """ + Sample from the posterior distribution + + Parameters + ---------- + kwargs + keyword arguments + """ diff --git a/sbijax/_sne_base.py b/sbijax/_sne_base.py new file mode 100644 index 0000000..3099e93 --- /dev/null +++ b/sbijax/_sne_base.py @@ -0,0 +1,84 @@ 
+from abc import ABC
+from typing import Iterable
+
+from jax import numpy as jnp
+
+from sbijax import generator
+from sbijax._sbi_base import SBI
+from sbijax.generator import named_dataset
+
+
+# pylint: disable=too-many-arguments,unused-argument
+# pylint: disable=too-many-function-args,arguments-differ
+class SNE(SBI, ABC):
+    """
+    Sequential neural estimation
+    """
+
+    def __init__(self, model_fns, density_estimator):
+        super().__init__(model_fns)
+        self.model = density_estimator
+        self.n_total_simulations = 0
+        self._train_iter: Iterable
+        self._val_iter: Iterable
+
+    def simulate_new_data_and_append(self, params, n_simulations):
+        """
+        Simulate new data-parameter pairs and append them to the
+        existing data set.
+
+        Parameters
+        ----------
+        params: pytree
+            parameter set of the neural network
+        n_simulations: int
+            number of data-parameter pairs to draw
+
+        Returns
+        -------
+        the data set with the new simulations appended
+        """
+
+        self.data, _ = self._simulate_new_data_and_append(
+            params, self.data, n_simulations
+        )
+        return self.data
+
+    def _simulate_new_data_and_append(
+        self,
+        params,
+        D,
+        n_simulations_per_round,
+        **kwargs,
+    ):
+        if D is None:
+            diagnostics = None
+            self.n_total_simulations += n_simulations_per_round
+            new_thetas = self.prior_sampler_fn(
+                seed=next(self._rng_seq),
+                sample_shape=(n_simulations_per_round,),
+            )
+        else:
+            new_thetas, diagnostics = self.sample_posterior(
+                params, n_simulations_per_round, **kwargs
+            )
+
+        new_obs = self.simulator_fn(seed=next(self._rng_seq), theta=new_thetas)
+        new_data = named_dataset(new_obs, new_thetas)
+        if D is None:
+            d_new = new_data
+        else:
+            d_new = named_dataset(
+                *[jnp.vstack([a, b]) for a, b in zip(D, new_data)]
+            )
+        return d_new, diagnostics
+
+    def as_iterators(self, D, batch_size, percentage_data_as_validation_set):
+        """Convert the data set to training and validation batch iterators"""
+        return generator.as_batch_iterators(
+            next(self._rng_seq),
+            D,
+            batch_size,
+            1.0 - percentage_data_as_validation_set,
+            True,
+        )
diff --git a/sbijax/abc/rejection_abc.py b/sbijax/abc/rejection_abc.py
index 5f2b510..bda54de 100644
--- a/sbijax/abc/rejection_abc.py
+++ b/sbijax/abc/rejection_abc.py
@@ -19,7 +19,7 @@ def __init__(self, model_fns, summary_fn, kernel_fn):
         self.summary_fn = summary_fn
         self.summarized_observed: chex.Array
 
-    def fit(self, rng_key, observed):
+    def fit(self, rng_key, observed, **kwargs):
         super().fit(rng_key, observed)
         self.summarized_observed = self.summary_fn(self.observed)
 
diff --git a/sbijax/abc/smc_abc.py b/sbijax/abc/smc_abc.py
index ac4f601..08da1d8 100644
--- a/sbijax/abc/smc_abc.py
+++ b/sbijax/abc/smc_abc.py
@@ -11,6 +11,7 @@
 from sbijax._sbi_base import SBI
 
 
+# pylint: disable=arguments-differ,too-many-function-args
 class SMCABC(SBI):
     """
    Sisson et al. 
- Handbook of approximate Bayesian computation diff --git a/sbijax/generator.py b/sbijax/generator.py index f471668..c6cb7a8 100644 --- a/sbijax/generator.py +++ b/sbijax/generator.py @@ -5,7 +5,7 @@ from jax import numpy as jnp from jax import random -named_dataset = namedtuple("named_dataset", "y x") +named_dataset = namedtuple("named_dataset", "y theta") # pylint: disable=missing-class-docstring,too-few-public-methods diff --git a/sbijax/mcmc/__init__.py b/sbijax/mcmc/__init__.py index 05d55af..b83d7e2 100644 --- a/sbijax/mcmc/__init__.py +++ b/sbijax/mcmc/__init__.py @@ -1,2 +1,3 @@ from sbijax.mcmc.nuts import sample_with_nuts from sbijax.mcmc.sample import mcmc_diagnostics +from sbijax.mcmc.slice import sample_with_slice diff --git a/sbijax/snl.py b/sbijax/snl.py index 91f6372..e893b80 100644 --- a/sbijax/snl.py +++ b/sbijax/snl.py @@ -1,44 +1,25 @@ from collections import namedtuple from functools import partial -from typing import Iterable -import chex -import haiku as hk import jax import numpy as np import optax from absl import logging from flax.training.early_stopping import EarlyStopping from jax import numpy as jnp -from jax import random -from sbijax import generator -from sbijax._sbi_base import SBI -from sbijax.generator import named_dataset -from sbijax.mcmc import sample_with_nuts -from sbijax.mcmc.sample import mcmc_diagnostics +from sbijax._sne_base import SNE +from sbijax.mcmc import mcmc_diagnostics, sample_with_nuts, sample_with_slice -# pylint: disable=too-many-arguments -from sbijax.mcmc.slice import sample_with_slice - -class SNL(SBI): +# pylint: disable=too-many-arguments,unused-argument +class SNL(SNE): """ Sequential neural likelihood From the Papamakarios paper """ - def __init__(self, model_fns, density_estimator): - super().__init__(model_fns) - self.model = density_estimator - self._len_theta = len(self.prior_sampler_fn(seed=random.PRNGKey(0))) - - self.observed: chex.Array - self._rng_seq: hk.PRNGSequence - self._train_iter: Iterable - self._val_iter: Iterable - # pylint: disable=arguments-differ,too-many-locals def fit( self, @@ -120,12 +101,8 @@ def fit( ) for _ in range(n_rounds): D, diagnostics = simulator_fn(params, D, **kwargs) - self._train_iter, self._val_iter = generator.as_batch_iterators( - next(self._rng_seq), - D, - batch_size, - 1.0 - percentage_data_as_validation_set, - True, + self._train_iter, self._val_iter = self.as_iterators( + D, batch_size, percentage_data_as_validation_set ) params, losses = self._fit_model_single_round( optimizer=optimizer, @@ -168,20 +145,62 @@ def sample_posterior(self, params, n_chains, n_samples, n_warmup, **kwargs): (n_samples \times p) """ - return self._simulate_from_amortized_posterior( - params, n_chains, n_samples, n_warmup, **kwargs + part = partial( + self.model.apply, params=params, method="log_prob", y=self.observed ) + def _log_likelihood_fn(theta): + theta = jnp.tile(theta, [self.observed.shape[0], 1]) + return part(x=theta) + + def _joint_logdensity_fn(theta): + lp_prior = self.prior_log_density_fn(theta) + lp = _log_likelihood_fn(theta) + return jnp.sum(lp) + jnp.sum(lp_prior) + + if "sampler" in kwargs and kwargs["sampler"] == "slice": + + def lp__(theta): + return jax.vmap(_joint_logdensity_fn)(theta) + + kwargs.pop("sampler", None) + samples = sample_with_slice( + self._rng_seq, + lp__, + n_chains, + n_samples, + n_warmup, + self.prior_sampler_fn, + **kwargs, + ) + else: + + def lp__(theta): + return _joint_logdensity_fn(**theta) + + samples = sample_with_nuts( + self._rng_seq, + lp__, + 
self._len_theta, + n_chains, + n_samples, + n_warmup, + ) + diagnostics = mcmc_diagnostics(samples) + return samples, diagnostics + def _fit_model_single_round( self, optimizer, max_n_iter, n_early_stopping_patience ): - params = self._init_params(next(self._rng_seq), self._train_iter(0)) + params = self._init_params(next(self._rng_seq), **self._train_iter(0)) state = optimizer.init(params) @jax.jit def step(params, state, **batch): def loss_fn(params): - lp = self.model.apply(params, method="log_prob", **batch) + lp = self.model.apply( + params, method="log_prob", y=batch["y"], x=batch["theta"] + ) return -jnp.sum(lp) loss, grads = jax.value_and_grad(loss_fn)(params) @@ -211,7 +230,9 @@ def loss_fn(params): def _validation_loss(self, params): def _loss_fn(**batch): - lp = self.model.apply(params, method="log_prob", **batch) + lp = self.model.apply( + params, method="log_prob", y=batch["y"], x=batch["theta"] + ) return -jnp.sum(lp) losses = jnp.array( @@ -222,87 +243,8 @@ def _loss_fn(**batch): ) return jnp.sum(losses) - def _init_params(self, rng_key, init_data): - params = self.model.init(rng_key, method="log_prob", **init_data) - return params - - def _simulate_new_data_and_append( - self, - params, - D, - n_simulations_per_round, - n_chains, - n_samples, - n_warmup, - **kwargs, - ): - if D is None: - diagnostics = None - new_thetas = self.prior_sampler_fn( - seed=next(self._rng_seq), - sample_shape=(n_simulations_per_round,), - ) - else: - new_thetas, diagnostics = self._simulate_from_amortized_posterior( - params, n_chains, n_samples, n_warmup, **kwargs - ) - new_thetas = new_thetas.reshape(-1, self._len_theta) - new_thetas = random.permutation(next(self._rng_seq), new_thetas) - new_thetas = new_thetas[:n_simulations_per_round, :] - - new_obs = self.simulator_fn(seed=next(self._rng_seq), theta=new_thetas) - new_data = named_dataset(new_obs, new_thetas) - if D is None: - d_new = new_data - else: - d_new = named_dataset( - *[jnp.vstack([a, b]) for a, b in zip(D, new_data)] - ) - return d_new, diagnostics - - def _simulate_from_amortized_posterior( - self, params, n_chains, n_samples, n_warmup, **kwargs - ): - part = partial( - self.model.apply, params=params, method="log_prob", y=self.observed + def _init_params(self, rng_key, **init_data): + params = self.model.init( + rng_key, method="log_prob", y=init_data["y"], x=init_data["theta"] ) - - def _log_likelihood_fn(theta): - theta = jnp.tile(theta, [self.observed.shape[0], 1]) - return part(x=theta) - - def _joint_logdensity_fn(theta): - lp_prior = self.prior_log_density_fn(theta) - lp = _log_likelihood_fn(theta) - return jnp.sum(lp) + jnp.sum(lp_prior) - - if "sampler" in kwargs and kwargs["sampler"] == "slice": - - def lp__(theta): - return jax.vmap(_joint_logdensity_fn)(theta) - - kwargs.pop("sampler", None) - samples = sample_with_slice( - self._rng_seq, - lp__, - n_chains, - n_samples, - n_warmup, - self.prior_sampler_fn, - **kwargs, - ) - else: - - def lp__(theta): - return _joint_logdensity_fn(**theta) - - samples = sample_with_nuts( - self._rng_seq, - lp__, - self._len_theta, - n_chains, - n_samples, - n_warmup, - ) - diagnostics = mcmc_diagnostics(samples) - return samples, diagnostics + return params diff --git a/sbijax/snp.py b/sbijax/snp.py new file mode 100644 index 0000000..801ec56 --- /dev/null +++ b/sbijax/snp.py @@ -0,0 +1,266 @@ +from collections import namedtuple +from functools import partial + +import jax +import numpy as np +import optax +from absl import logging +from flax.training.early_stopping import 
EarlyStopping
+from jax import numpy as jnp
+from jax import random
+from jax import scipy as jsp
+
+from sbijax._sne_base import SNE
+
+
+# pylint: disable=too-many-arguments,unused-argument
+class SNP(SNE):
+    """
+    Sequential neural posterior estimation
+
+    From Greenberg et al. (2019), "Automatic Posterior Transformation"
+    """
+
+    # pylint: disable=arguments-differ,too-many-locals
+    def fit(
+        self,
+        rng_key,
+        observed,
+        optimizer,
+        n_rounds=10,
+        n_simulations_per_round=1000,
+        n_atoms=10,
+        max_n_iter=1000,
+        batch_size=128,
+        percentage_data_as_validation_set=0.05,
+        n_early_stopping_patience=10,
+        **kwargs,
+    ):
+        """
+        Fit an SNPE model
+
+        Parameters
+        ----------
+        rng_key: chex.PRNGKey
+            a random key
+        observed: chex.Array
+            (n \times p)-dimensional array of observations, where `n` is the
+            number of samples
+        optimizer: optax.GradientTransformation
+            an optax optimizer object
+        n_rounds: int
+            number of rounds to optimize
+        n_simulations_per_round: int
+            number of data simulations per round
+        n_atoms: int
+            number of atoms to approximate the proposal posterior
+        max_n_iter: int
+            maximal number of training iterations per round
+        batch_size: int
+            batch size used for training the model
+        percentage_data_as_validation_set: float
+            percentage of the simulated data that is used for validation and
+            early stopping
+        n_early_stopping_patience: int
+            number of iterations with no improvement of the validation loss
+            before stopping optimisation
+
+        Returns
+        -------
+        Tuple[pytree, Tuple]
+            returns a tuple of parameters and a tuple of the training
+            information
+        """
+
+        super().fit(rng_key, observed)
+
+        simulator_fn = partial(
+            self._simulate_new_data_and_append,
+            n_simulations_per_round=n_simulations_per_round,
+        )
+        D, params, all_losses, all_params = None, None, [], []
+        for i_round in range(n_rounds):
+            D, _ = simulator_fn(params, D, **kwargs)
+            self._train_iter, self._val_iter = self.as_iterators(
+                D, batch_size, percentage_data_as_validation_set
+            )
+            params, losses = self._fit_model_single_round(
+                optimizer=optimizer,
+                max_n_iter=max_n_iter,
+                n_early_stopping_patience=n_early_stopping_patience,
+                n_round=i_round,
+                n_atoms=n_atoms,
+            )
+            all_params.append(params.copy())
+            all_losses.append(losses)
+
+        snp_info = namedtuple("snp_info", "params losses")
+        return params, snp_info(all_params, all_losses)
+
+    def sample_posterior(self, params, n_samples, **kwargs):
+        """
+        Sample from the approximate posterior
+
+        Parameters
+        ----------
+        params: pytree
+            a pytree of parameters for the model
+        n_samples: int
+            number of samples to draw
+
+        Returns
+        -------
+        Tuple[chex.Array, float]
+            an array of samples from the posterior distribution of dimension
+            (n_samples \times p) and the acceptance rate of the rejection step
+        """
+        thetas = None
+        n_curr = n_samples
+        n_total_simulations_round = 0
+        while n_curr > 0:
+            n_sim = jnp.maximum(100, n_curr)
+            n_total_simulations_round += n_sim
+            proposal = self.model.apply(
+                params,
+                next(self.rng_seq),
+                method="sample",
+                sample_shape=(n_sim,),
+                x=jnp.tile(self.observed, [n_sim, 1]),
+            )
+            proposal_probs = self.prior_log_density_fn(proposal)
+            proposal_accepted = proposal[jnp.isfinite(proposal_probs)]
+            if thetas is None:
+                thetas = proposal_accepted
+            else:
+                thetas = jnp.vstack([thetas, proposal_accepted])
+            n_curr -= proposal_accepted.shape[0]
+        self.n_total_simulations += n_total_simulations_round
+        return thetas[:n_samples], thetas.shape[0] / n_total_simulations_round
+
+    def _fit_model_single_round(
+        self, optimizer, max_n_iter, n_early_stopping_patience, n_round, n_atoms
+    ):
+        params = self._init_params(next(self._rng_seq), 
**self._train_iter(0)) + state = optimizer.init(params) + + if n_round == 0: + + def loss_fn(params, rng, **batch): + lp = self.model.apply( + params, + None, + method="log_prob", + y=batch["theta"], + x=batch["y"], + ) + return -jnp.sum(lp) + + else: + + def loss_fn(params, rng, **batch): + lp = self._proposal_posterior_log_prob( + params, + rng, + n_atoms, + theta=batch["theta"], + y=batch["y"], + ) + return -jnp.sum(lp) + + @jax.jit + def step(params, rng, state, **batch): + loss, grads = jax.value_and_grad(loss_fn)(params, rng, **batch) + updates, new_state = optimizer.update(grads, state, params) + new_params = optax.apply_updates(params, updates) + return loss, new_params, new_state + + losses = np.zeros([max_n_iter, 2]) + early_stop = EarlyStopping(1e-3, n_early_stopping_patience) + logging.info("training model") + for i in range(max_n_iter): + train_loss = 0.0 + for j in range(self._train_iter.num_batches): + batch = self._train_iter(j) + batch_loss, params, state = step( + params, next(self.rng_seq), state, **batch + ) + train_loss += batch_loss + validation_loss = self._validation_loss( + params, next(self.rng_seq), n_round, n_atoms + ) + losses[i] = jnp.array([train_loss, validation_loss]) + + _, early_stop = early_stop.update(validation_loss) + if early_stop.should_stop: + logging.info("early stopping criterion found") + break + + losses = jnp.vstack(losses)[:i, :] + return params, losses + + def _init_params(self, rng_key, **init_data): + params = self.model.init( + rng_key, method="log_prob", y=init_data["theta"], x=init_data["y"] + ) + return params + + def _proposal_posterior_log_prob(self, params, rng, n_atoms, theta, y): + n = theta.shape[0] + n_atoms = np.maximum(2, np.minimum(n_atoms, n)) + repeated_y = jnp.repeat(y, n_atoms, axis=0) + probs = jnp.ones((n, n)) * (1 - jnp.eye(n)) / (n - 1) + + choice = partial( + random.choice, a=jnp.arange(n), replace=False, shape=(n_atoms - 1,) + ) + sample_keys = random.split(rng, probs.shape[0]) + choices = jax.vmap(lambda key, prob: choice(key, p=prob))( + sample_keys, probs + ) + contrasting_theta = theta[choices] + + atomic_theta = jnp.concatenate( + (theta[:, None, :], contrasting_theta), axis=1 + ) + atomic_theta = atomic_theta.reshape(n * n_atoms, -1) + + log_prob_posterior = self.model.apply( + params, None, method="log_prob", y=atomic_theta, x=repeated_y + ) + log_prob_posterior = log_prob_posterior.reshape(n, n_atoms) + log_prob_prior = self.prior_log_density_fn(atomic_theta) + log_prob_prior = log_prob_prior.reshape(n, n_atoms) + + unnormalized_log_prob = log_prob_posterior - log_prob_prior + log_prob_proposal_posterior = unnormalized_log_prob[ + :, 0 + ] - jsp.special.logsumexp(unnormalized_log_prob, axis=-1) + + return log_prob_proposal_posterior + + def _validation_loss(self, params, seed, n_round, n_atoms): + if n_round == 0: + + def loss_fn(rng, **batch): + lp = self.model.apply( + params, + None, + method="log_prob", + y=batch["theta"], + x=batch["y"], + ) + return -jnp.sum(lp) + + else: + + def loss_fn(rng, **batch): + lp = self._proposal_posterior_log_prob( + params, rng, n_atoms, batch["theta"], batch["y"] + ) + return -jnp.sum(lp) + + loss = 0 + for j in range(self._val_iter.num_batches): + rng, seed = random.split(seed) + loss += jax.jit(loss_fn)(rng, **self._val_iter(j)) + return loss
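A note on the atom-based objective in sbijax/snp.py: `_proposal_posterior_log_prob` follows the atomic proposal-posterior construction of SNPE-C/APT (the Greenberg et al. paper cited in the class docstring). Each theta in a batch is grouped with n_atoms - 1 contrasting parameters drawn from the same batch, the prior is divided out of the density estimate, and the result is renormalized over the atoms with a logsumexp. The sketch below is a minimal, self-contained version of that computation so it can be sanity-checked without haiku or a trained flow; the names atomic_log_prob, log_q_fn, log_prior_fn and the toy Gaussian/flat densities are illustrative assumptions, not part of the library.

import jax
from jax import numpy as jnp
from jax import random
from jax import scipy as jsp


def atomic_log_prob(log_q_fn, log_prior_fn, theta, y, n_atoms, rng):
    """Atomic (contrastive) proposal-posterior log-probability."""
    n = theta.shape[0]
    n_atoms = max(2, min(n_atoms, n))
    # for every row, draw n_atoms - 1 contrasting parameters from the rest
    # of the batch, i.e. exclude the row's own theta
    probs = jnp.ones((n, n)) * (1.0 - jnp.eye(n)) / (n - 1)
    keys = random.split(rng, n)
    choices = jax.vmap(
        lambda key, p: random.choice(
            key, n, shape=(n_atoms - 1,), replace=False, p=p
        )
    )(keys, probs)
    atomic_theta = jnp.concatenate((theta[:, None, :], theta[choices]), axis=1)
    atomic_theta = atomic_theta.reshape(n * n_atoms, -1)
    repeated_y = jnp.repeat(y, n_atoms, axis=0)
    # divide out the prior, then renormalize over the atoms
    log_ratio = log_q_fn(atomic_theta, repeated_y) - log_prior_fn(atomic_theta)
    log_ratio = log_ratio.reshape(n, n_atoms)
    return log_ratio[:, 0] - jsp.special.logsumexp(log_ratio, axis=-1)


# toy check: a Gaussian stand-in for the conditional density estimator and a
# flat prior, so the ratio reduces to the density itself
rng_key = random.PRNGKey(0)
theta = random.normal(rng_key, (8, 2))
y = theta + 0.1 * random.normal(random.PRNGKey(1), (8, 2))
log_q = lambda t, y_obs: -0.5 * jnp.sum((t - y_obs) ** 2, axis=-1)
log_flat_prior = lambda t: jnp.zeros(t.shape[0])
print(atomic_log_prob(log_q, log_flat_prior, theta, y, n_atoms=4, rng=rng_key))

Substituting the transformed model's log_prob for log_q_fn and the prior log-density for log_prior_fn recovers the loss that SNP._fit_model_single_round uses in rounds after the first.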