From 41dd5b007d2851cf784a00484fe15fa8e7faf27f Mon Sep 17 00:00:00 2001 From: Simon Dirmeier Date: Wed, 28 Feb 2024 22:04:30 +0100 Subject: [PATCH 1/6] Impl consistency model posterior estimation --- README.md | 1 + docs/sbijax.rst | 7 + examples/bivariate_gaussian_cfmpe.py | 85 ++++++++ sbijax/__init__.py | 5 +- sbijax/_src/nn/consistency_model.py | 229 +++++++++++++++++++++ sbijax/_src/scmpe.py | 294 +++++++++++++++++++++++++++ sbijax/_src/{fmpe.py => sfmpe.py} | 0 sbijax/nn/__init__.py | 4 + 8 files changed, 623 insertions(+), 2 deletions(-) create mode 100644 examples/bivariate_gaussian_cfmpe.py create mode 100644 sbijax/_src/nn/consistency_model.py create mode 100644 sbijax/_src/scmpe.py rename sbijax/_src/{fmpe.py => sfmpe.py} (100%) diff --git a/README.md b/README.md index 7aac6bb..6ca94a5 100644 --- a/README.md +++ b/README.md @@ -53,6 +53,7 @@ pip install git+https://github.com/dirmeier/sbijax@ ## Acknowledgements +> [!NOTE] > 📝 The API of the package is heavily inspired by the excellent Pytorch-based [`sbi`](https://github.com/sbi-dev/sbi) package which is substantially more feature-complete and user-friendly, and better documented. diff --git a/docs/sbijax.rst b/docs/sbijax.rst index 332fa0c..5ca2633 100644 --- a/docs/sbijax.rst +++ b/docs/sbijax.rst @@ -16,6 +16,7 @@ Methods SNP SNR SFMPE + SCMPE SNASS SNASSS @@ -46,6 +47,12 @@ SNR SFMPE ~~~~~ +.. autoclass:: SFMPE + :members: fit, simulate_data_and_possibly_append, sample_posterior + +SCMPE +~~~~~ + .. autoclass:: SFMPE :members: fit, simulate_data_and_possibly_append, sample_posterior diff --git a/examples/bivariate_gaussian_cfmpe.py b/examples/bivariate_gaussian_cfmpe.py new file mode 100644 index 0000000..0c1570c --- /dev/null +++ b/examples/bivariate_gaussian_cfmpe.py @@ -0,0 +1,85 @@ +""" +Example using consistency model posterior estimation on a bivariate Gaussian +""" + +import distrax +import haiku as hk +import matplotlib.pyplot as plt +import optax +import seaborn as sns +from jax import numpy as jnp +from jax import random as jr + +from sbijax import SCMPE +from sbijax.nn import ConsistencyModel + + +def prior_model_fns(): + p = distrax.Independent(distrax.Normal(jnp.zeros(2), jnp.ones(2)), 1) + return p.sample, p.log_prob + + +def simulator_fn(seed, theta): + p = distrax.Normal(jnp.zeros_like(theta), 1.0) + y = theta + p.sample(seed=seed) + return y + + +def make_model(dim): + @hk.transform + def _mlp(method, **kwargs): + def _c_skip(time): + return 1 / ((time - 0.001) ** 2 + 1) + + def _c_out(time): + return 1.0 * (time - 0.001) / jnp.sqrt(1 + time ** 2) + def _nn(theta, time, context, **kwargs): + ins = jnp.concatenate([theta, time, context], axis=-1) + outs = hk.nets.MLP([64, 64, dim])(ins) + out_skip = _c_skip(time) * theta + _c_out(time) * outs + return out_skip + + cm = ConsistencyModel(dim, _nn) + return cm(method, **kwargs) + + return _mlp + + +def run(): + y_observed = jnp.array([2.0, -2.0]) + + prior_simulator_fn, prior_logdensity_fn = prior_model_fns() + fns = (prior_simulator_fn, prior_logdensity_fn), simulator_fn + + estim = SCMPE(fns, make_model(2)) + optimizer = optax.adam(1e-3) + + data, params = None, {} + for i in range(2): + data, _ = estim.simulate_data_and_possibly_append( + jr.fold_in(jr.PRNGKey(1), i), + params=params, + observable=y_observed, + data=data, + ) + params, info = estim.fit( + jr.fold_in(jr.PRNGKey(2), i), + data=data, + optimizer=optimizer, + ) + + + rng_key = jr.PRNGKey(23) + post_samples, _ = estim.sample_posterior(rng_key, params, y_observed) + print(post_samples) + fig, axes 
= plt.subplots(2) + for i, ax in enumerate(axes): + sns.histplot(post_samples[:, i], color="darkblue", ax=ax) + ax.set_xlim([-3.0, 3.0]) + sns.despine() + plt.tight_layout() + plt.show() + + +if __name__ == "__main__": + run() diff --git a/sbijax/__init__.py b/sbijax/__init__.py index 646b951..75ada41 100644 --- a/sbijax/__init__.py +++ b/sbijax/__init__.py @@ -2,11 +2,12 @@ sbijax: Simulation-based inference in JAX """ -__version__ = "0.1.9" +__version__ = "0.2.0" from sbijax._src.abc.smc_abc import SMCABC -from sbijax._src.fmpe import SFMPE +from sbijax._src.scmpe import SCMPE +from sbijax._src.sfmpe import SFMPE from sbijax._src.snass import SNASS from sbijax._src.snasss import SNASSS from sbijax._src.snl import SNL diff --git a/sbijax/_src/nn/consistency_model.py b/sbijax/_src/nn/consistency_model.py new file mode 100644 index 0000000..02259e3 --- /dev/null +++ b/sbijax/_src/nn/consistency_model.py @@ -0,0 +1,229 @@ +from typing import Callable + +import distrax +import haiku as hk +import jax +from jax import numpy as jnp +from jax.nn import glu +from scipy import integrate + +__all__ = ["ConsistencyModel", "make_consistency_model"] + +from sbijax._src.nn.make_resnet import _Resnet + + +class ConsistencyModel(hk.Module): + """Conditional continuous normalizing flow. + + Args: + n_dimension: the dimensionality of the modelled space + transform: a haiku module. The transform is a callable that has to + take as input arguments named 'theta', 'time', 'context' and + **kwargs. Theta, time and context are two-dimensional arrays + with the same batch dimensions. + """ + + def __init__(self, n_dimension: int, transform: Callable, t_max=50): + """Conditional continuous normalizing flow. + + Args: + n_dimension: the dimensionality of the modelled space + transform: a haiku module. The transform is a callable that has to + take as input arguments named 'theta', 'time', 'context' and + **kwargs. Theta, time and context are two-dimensional arrays + with the same batch dimensions. + """ + super().__init__() + self._n_dimension = n_dimension + self._network = transform + self._t_max = t_max + self._base_distribution = distrax.Normal(jnp.zeros(n_dimension), 1.0) + + def __call__(self, method, **kwargs): + """Aplpy the flow. + + Args: + method (str): method to call + + Keyword Args: + keyword arguments for the called method: + """ + return getattr(self, method)(**kwargs) + + def sample(self, context): + """Sample from the pushforward. + + Args: + context: array of conditioning variables + """ + theta_0 = self._base_distribution.sample( + seed=hk.next_rng_key(), sample_shape=(context.shape[0],) + ) + y_hat = self.vector_field(theta_0, self._t_max, context) + return y_hat + + def vector_field(self, theta, time, context, **kwargs): + """Compute the vector field. 
+ + Args: + theta: array of parameters + time: time variables + context: array of conditioning variables + + Keyword Args: + keyword arguments that aer passed tothe neural network + """ + time = jnp.full((theta.shape[0], 1), time) + return self._network(theta=theta, time=time, context=context, **kwargs) + + +# pylint: disable=too-many-arguments +class _ResnetBlock(hk.Module): + """A block for a 1d residual network.""" + + def __init__( + self, + hidden_size: int, + activation: Callable = jax.nn.relu, + dropout_rate: float = 0.2, + do_batch_norm: bool = False, + batch_norm_decay: float = 0.1, + ): + super().__init__() + self.hidden_size = hidden_size + self.activation = activation + self.do_batch_norm = do_batch_norm + self.dropout_rate = dropout_rate + self.batch_norm_decay = batch_norm_decay + + def __call__(self, inputs, context, is_training=False): + outputs = inputs + if self.do_batch_norm: + outputs = hk.BatchNorm(True, True, self.batch_norm_decay)( + outputs, is_training=is_training + ) + outputs = hk.Linear(self.hidden_size)(outputs) + outputs = self.activation(outputs) + if is_training: + outputs = hk.dropout( + rng=hk.next_rng_key(), rate=self.dropout_rate, x=outputs + ) + outputs = hk.Linear(self.hidden_size)(outputs) + context_proj = hk.Linear(inputs.shape[-1])(context) + outputs = glu(jnp.concatenate([outputs, context_proj], axis=-1)) + return outputs + inputs + + +# pylint: disable=too-many-arguments +class _CMResnet(hk.Module): + """A simplified 1-d residual network.""" + + def __init__( + self, + n_layers: int, + n_dimension: int, + hidden_size: int, + activation: Callable = jax.nn.relu, + dropout_rate: float = 0.0, + do_batch_norm: bool = False, + batch_norm_decay: float = 0.1, + eps: float = 0.001, + sigma_data:float = 1.0 + ): + super().__init__() + self.n_layers = n_layers + self.n_dimension = n_dimension + self.hidden_size = hidden_size + self.activation = activation + self.do_batch_norm = do_batch_norm + self.dropout_rate = dropout_rate + self.batch_norm_decay = batch_norm_decay + self.sigma_data = sigma_data + self.var_data = self.sigma_data**2 + self.eps = eps + + def __call__(self, theta, time, context, is_training=False, **kwargs): + outputs = context + t_theta_embedding = jnp.concatenate( + [ + hk.Linear(self.n_dimension)(theta), + hk.Linear(self.n_dimension)(time), + ], + axis=-1, + ) + outputs = hk.Linear(self.hidden_size)(outputs) + outputs = self.activation(outputs) + for _ in range(self.n_layers): + outputs = _ResnetBlock( + hidden_size=self.hidden_size, + activation=self.activation, + dropout_rate=self.dropout_rate, + do_batch_norm=self.do_batch_norm, + batch_norm_decay=self.batch_norm_decay, + )(outputs, context=t_theta_embedding, is_training=is_training) + outputs = self.activation(outputs) + outputs = hk.Linear(self.n_dimension)(outputs) + + # TODO(simon): how is sigma_data chosen automatically? 
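        # The skip/out weighting below uses the standard consistency-model
        # parameterisation implemented in _c_skip/_c_out further down:
        # c_skip(t) = sigma_data^2 / ((t - eps)^2 + sigma_data^2) and
        # c_out(t)  = sigma_data * (t - eps) / sqrt(sigma_data^2 + t^2),
        # so c_skip(eps) = 1 and c_out(eps) = 0, i.e. the model reduces to the
        # identity map f(theta, eps, y) = theta at the smallest time step.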
+ # in the meantime set it to 1 and use batch norm before + #outputs = hk.BatchNorm(True, True, self.batch_norm_decay)(outputs, is_training=is_training) + out_skip = self._c_skip(time) * theta + self._c_out(time) * outputs + return out_skip + + def _c_skip(self, time): + return self.var_data / ((time - self.eps) ** 2 + self.var_data) + + def _c_out(self, time): + return ( + self.sigma_data + * (time - self.eps) + / jnp.sqrt(self.var_data + time**2) + ) + + +def make_consistency_model( + n_dimension: int, + n_layers: int = 2, + hidden_size: int = 64, + activation: Callable = jax.nn.tanh, + dropout_rate: float = 0.2, + do_batch_norm: bool = False, + batch_norm_decay: float = 0.2, + t_max: float=50, + epsilon=0.001, + sigma_data:float=1.0 +): + """Create a conditional continuous normalizing flow. + + The CCNF uses a residual network as transformer which is created + automatically. + + Args: + n_dimension: dimensionality of modelled space + n_layers: number of resnet blocks + hidden_size: sizes of hidden layers for each resnet block + activation: a jax activation function + dropout_rate: dropout rate to use in resnet blocks + do_batch_norm: use batch normalization or not + batch_norm_decay: decay rate of EMA in batch norm layer + Returns: + returns a conditional continuous normalizing flow + """ + + @hk.transform + def _cm(method, **kwargs): + nn = _CMResnet( + n_layers=n_layers, + n_dimension=n_dimension, + hidden_size=hidden_size, + activation=activation, + do_batch_norm=do_batch_norm, + dropout_rate=dropout_rate, + batch_norm_decay=batch_norm_decay, + eps=epsilon, + sigma_data=sigma_data, + ) + cm = ConsistencyModel(n_dimension, nn, t_max=t_max) + return cm(method, **kwargs) + + return _cm diff --git a/sbijax/_src/scmpe.py b/sbijax/_src/scmpe.py new file mode 100644 index 0000000..e5fed87 --- /dev/null +++ b/sbijax/_src/scmpe.py @@ -0,0 +1,294 @@ +from functools import partial + +import jax +import numpy as np +import optax +from absl import logging +from jax import numpy as jnp +from jax import random as jr +from tqdm import tqdm + +from sbijax._src._sne_base import SNE +from sbijax._src.util.early_stopping import EarlyStopping + + +def _alpha_t(time): + return 1 / (_time_schedule(time + 1) - _time_schedule(time)) + + +def _time_schedule(n, rho=7, eps=0.001, T_max=50, N=1000): + left = eps ** (1 / rho) + right = T_max ** (1 / rho) - eps ** (1 / rho) + right = (n - 1) / (N - 1) * right + return (left + right) ** rho + + +# pylint: disable=too-many-locals +def _consistency_loss( + params, ema_params, rng_key, apply_fn, is_training=False, **batch +): + theta = batch["theta"] + + t_key, rng_key = jr.split(rng_key) + time_idx = jr.randint(t_key, shape=(theta.shape[0],), minval=1, maxval=1000 - 1) + tn = _time_schedule(time_idx).reshape(-1, 1) + tnp1 = _time_schedule(time_idx + 1).reshape(-1, 1) + + noise_key, rng_key = jr.split(rng_key) + noise = jr.normal(noise_key, shape=(*theta.shape,)) + + train_rng, rng_key = jr.split(rng_key) + fnp1 = apply_fn( + params, + train_rng, + method="vector_field", + theta=theta + tnp1 * noise, + time=tnp1, + context=batch["y"], + is_training=is_training, + ) + fn = apply_fn( + ema_params, + train_rng, + method="vector_field", + theta=theta + tn * noise, + time=tn, + context=batch["y"], + is_training=is_training, + ) + mse = jnp.mean(jnp.square(fnp1 - fn), axis=1) + loss = _alpha_t(time_idx) * mse + return jnp.mean(loss) + + +# pylint: disable=too-many-arguments,unused-argument,useless-parent-delegation +class SCMPE(SNE): + r"""Sequential consistency model 
posterior estimation. + + Implements a sequential version of the CMPE algorithm introduced in [1]_. + For all rounds $r > 1$ parameter samples + :math:`\theta \sim \hat{p}^r(\theta)` are drawn from + the approximate posterior instead of the prior when computing consistency + loss + + Args: + model_fns: a tuple of tuples. The first element is a tuple that + consists of functions to sample and evaluate the + log-probability of a data point. The second element is a + simulator function. + network: a neural network + + Examples: + >>> import distrax + >>> from sbijax import SCMPE + >>> from sbijax.nn import make_consistency_model + >>> + >>> prior = distrax.Normal(0.0, 1.0) + >>> s = lambda seed, theta: distrax.Normal(theta, 1.0).sample(seed=seed) + >>> fns = (prior.sample, prior.log_prob), s + >>> net = make_consistency_model(1) + >>> + >>> estim = SCMPE(fns, net) + + References: + .. [1] Wildberger, Jonas, et al. "Flow Matching for Scalable + Simulation-Based Inference." Advances in Neural Information + Processing Systems, 2024. + """ + + def __init__(self, model_fns, network): + """Construct a FMPE object. + + Args: + model_fns: a tuple of tuples. The first element is a tuple that + consists of functions to sample and evaluate the + log-probability of a data point. The second element is a + simulator function. + network: network: a neural network + """ + super().__init__(model_fns, network) + + # pylint: disable=arguments-differ,too-many-locals + def fit( + self, + rng_key, + data, + *, + optimizer=optax.adam(0.0003), + n_iter=1000, + batch_size=100, + percentage_data_as_validation_set=0.1, + n_early_stopping_patience=10, + **kwargs, + ): + """Fit the model. + + Args: + rng_key: a jax random key + data: data set obtained from calling + `simulate_data_and_possibly_append` + optimizer: an optax optimizer object + n_iter: maximal number of training iterations per round + batch_size: batch size used for training the model + percentage_data_as_validation_set: percentage of the simulated + data that is used for validation and early stopping + n_early_stopping_patience: number of iterations of no improvement + of training the flow before stopping optimisation\ + + Returns: + a tuple of parameters and a tuple of the training information + """ + itr_key, rng_key = jr.split(rng_key) + train_iter, val_iter = self.as_iterators( + itr_key, data, batch_size, percentage_data_as_validation_set + ) + params, losses = self._fit_model_single_round( + seed=rng_key, + train_iter=train_iter, + val_iter=val_iter, + optimizer=optimizer, + n_iter=n_iter, + n_early_stopping_patience=n_early_stopping_patience, + ) + + return params, losses + + # pylint: disable=undefined-loop-variable + def _fit_model_single_round( + self, + seed, + train_iter, + val_iter, + optimizer, + n_iter, + n_early_stopping_patience, + ): + init_key, seed = jr.split(seed) + params = self._init_params(init_key, **next(iter(train_iter))) + ema_params = params.copy() + state = optimizer.init(params) + + loss_fn = jax.jit( + partial( + _consistency_loss, apply_fn=self.model.apply, is_training=False + ) + ) + + @jax.jit + def ema_update(params, avg_params): + return optax.incremental_update(params, avg_params, step_size=0.001) + + @jax.jit + def step(params, ema_params, rng, state, **batch): + loss, grads = jax.value_and_grad(loss_fn)( + params, ema_params, rng, **batch + ) + updates, new_state = optimizer.update(grads, state, params) + new_params = optax.apply_updates(params, updates) + new_ema_params = ema_update(new_params, ema_params) + return 
loss, new_params, new_ema_params, new_state + + losses = np.zeros([n_iter, 2]) + early_stop = EarlyStopping(1e-3, n_early_stopping_patience) + best_params, best_loss = None, np.inf + logging.info("training model") + for i in tqdm(range(n_iter)): + train_loss = 0.0 + rng_key = jr.fold_in(seed, i) + for batch in train_iter: + train_key, rng_key = jr.split(rng_key) + batch_loss, params, ema_params, state = step( + params, ema_params, train_key, state, **batch + ) + train_loss += batch_loss * ( + batch["y"].shape[0] / train_iter.num_samples + ) + val_key, rng_key = jr.split(rng_key) + validation_loss = self._validation_loss( + val_key, params, ema_params, val_iter + ) + losses[i] = jnp.array([train_loss, validation_loss]) + + _, early_stop = early_stop.update(validation_loss) + if early_stop.should_stop: + logging.info("early stopping criterion found") + break + if validation_loss < best_loss: + best_loss = validation_loss + best_params = params.copy() + + losses = jnp.vstack(losses)[: (i + 1), :] + return best_params, losses + + def _init_params(self, rng_key, **init_data): + times = jr.uniform(jr.PRNGKey(0), shape=(init_data["y"].shape[0], 1)) + params = self.model.init( + rng_key, + method="vector_field", + theta=init_data["theta"], + time=times, + context=init_data["y"], + is_training=False, + ) + return params + + def _validation_loss(self, rng_key, params, ema_params, val_iter): + loss_fn = jax.jit( + partial( + _consistency_loss, apply_fn=self.model.apply, is_training=False + ) + ) + + def body_fn(batch_key, **batch): + loss = loss_fn(params, ema_params, batch_key, **batch) + return loss * (batch["y"].shape[0] / val_iter.num_samples) + + loss = 0.0 + for batch in val_iter: + val_key, rng_key = jr.split(rng_key) + loss += body_fn(val_key, **batch) + return loss + + def sample_posterior( + self, rng_key, params, observable, *, n_samples=4_000, **kwargs + ): + r"""Sample from the approximate posterior. 
+ + Args: + rng_key: a jax random key + params: a pytree of neural network parameters + observable: observation to condition on + n_samples: number of samples to draw + + Returns: + returns an array of samples from the posterior distribution of + dimension (n_samples \times p) + """ + observable = jnp.atleast_2d(observable) + + thetas = None + n_curr = n_samples + n_total_simulations_round = 0 + while n_curr > 0: + n_sim = jnp.minimum(200, jnp.maximum(200, n_curr)) + n_total_simulations_round += n_sim + sample_key, rng_key = jr.split(rng_key) + proposal = self.model.apply( + params, + sample_key, + method="sample", + context=jnp.tile(observable, [n_sim, 1]), + ) + proposal_probs = self.prior_log_density_fn(proposal) + proposal_accepted = proposal[jnp.isfinite(proposal_probs)] + if thetas is None: + thetas = proposal_accepted + else: + thetas = jnp.vstack([thetas, proposal_accepted]) + n_curr -= proposal_accepted.shape[0] + + self.n_total_simulations += n_total_simulations_round + return ( + thetas[:n_samples], + thetas.shape[0] / n_total_simulations_round, + ) diff --git a/sbijax/_src/fmpe.py b/sbijax/_src/sfmpe.py similarity index 100% rename from sbijax/_src/fmpe.py rename to sbijax/_src/sfmpe.py diff --git a/sbijax/nn/__init__.py b/sbijax/nn/__init__.py index 635b6f9..1972b41 100644 --- a/sbijax/nn/__init__.py +++ b/sbijax/nn/__init__.py @@ -1,5 +1,9 @@ """Neural network module.""" +from sbijax._src.nn.consistency_model import ( + ConsistencyModel, + make_consistency_model, +) from sbijax._src.nn.continuous_normalizing_flow import CCNF, make_ccnf from sbijax._src.nn.make_flows import ( make_affine_maf, From d7e16efa823acd467a3278a5ff15c18e0cb49ac7 Mon Sep 17 00:00:00 2001 From: Simon Dirmeier Date: Wed, 28 Feb 2024 22:17:31 +0100 Subject: [PATCH 2/6] kinda works now --- sbijax/_src/scmpe.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sbijax/_src/scmpe.py b/sbijax/_src/scmpe.py index e5fed87..8fa485c 100644 --- a/sbijax/_src/scmpe.py +++ b/sbijax/_src/scmpe.py @@ -176,7 +176,7 @@ def _fit_model_single_round( @jax.jit def ema_update(params, avg_params): - return optax.incremental_update(params, avg_params, step_size=0.001) + return optax.incremental_update(avg_params, params , step_size=0.01) @jax.jit def step(params, ema_params, rng, state, **batch): @@ -189,7 +189,7 @@ def step(params, ema_params, rng, state, **batch): return loss, new_params, new_ema_params, new_state losses = np.zeros([n_iter, 2]) - early_stop = EarlyStopping(1e-3, n_early_stopping_patience) + early_stop = EarlyStopping(1e-3, n_early_stopping_patience*2) best_params, best_loss = None, np.inf logging.info("training model") for i in tqdm(range(n_iter)): From 218a4b4a8e770c96f6c7c0d589dd72b526f5e76d Mon Sep 17 00:00:00 2001 From: Simon Dirmeier Date: Thu, 29 Feb 2024 11:19:38 +0100 Subject: [PATCH 3/6] fix some things and docu --- README.md | 1 + docs/index.rst | 1 + docs/sbijax.rst | 2 +- examples/bivariate_gaussian_cfmpe.py | 4 +- sbijax/_src/nn/consistency_model.py | 115 +++++++---------- sbijax/_src/scmpe.py | 186 ++++++++++----------------- sbijax/_src/sfmpe.py | 6 +- 7 files changed, 127 insertions(+), 188 deletions(-) diff --git a/README.md b/README.md index 6ca94a5..f14e859 100644 --- a/README.md +++ b/README.md @@ -20,6 +20,7 @@ - [Neural Approximate Sufficient Statistics](https://arxiv.org/abs/2010.10079) (`SNASS`) - [Neural Approximate Slice Sufficient Statistics](https://openreview.net/forum?id=jjzJ768iV1) (`SNASSS`) - [Flow matching posterior 
estimation](https://openreview.net/forum?id=jjzJ768iV1) (`SFMPE`) +- [Consistency model posterior estimation](https://arxiv.org/abs/2312.05440) (`SCMPE`) where the acronyms in parentheses denote the names of the methods in `sbijax`. diff --git a/docs/index.rst b/docs/index.rst index 4f3b482..b25e991 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -21,6 +21,7 @@ - `Neural Approximate Sufficient Statistics `_ (:code:`SNASS`) - `Neural Approximate Slice Sufficient Statistics `_ (:code:`SNASSS`) - `Flow matching posterior estimation `_ (:code:`SFMPE`) +- `Consistency model posterior estimation `_ (:code:`SCMPE`) .. caution:: diff --git a/docs/sbijax.rst b/docs/sbijax.rst index 5ca2633..a074306 100644 --- a/docs/sbijax.rst +++ b/docs/sbijax.rst @@ -53,7 +53,7 @@ SFMPE SCMPE ~~~~~ -.. autoclass:: SFMPE +.. autoclass:: SCMPE :members: fit, simulate_data_and_possibly_append, sample_posterior SNASS diff --git a/examples/bivariate_gaussian_cfmpe.py b/examples/bivariate_gaussian_cfmpe.py index 0c1570c..01404e4 100644 --- a/examples/bivariate_gaussian_cfmpe.py +++ b/examples/bivariate_gaussian_cfmpe.py @@ -32,7 +32,8 @@ def _c_skip(time): return 1 / ((time - 0.001) ** 2 + 1) def _c_out(time): - return 1.0 * (time - 0.001) / jnp.sqrt(1 + time ** 2) + return 1.0 * (time - 0.001) / jnp.sqrt(1 + time**2) + def _nn(theta, time, context, **kwargs): ins = jnp.concatenate([theta, time, context], axis=-1) outs = hk.nets.MLP([64, 64, dim])(ins) @@ -68,7 +69,6 @@ def run(): optimizer=optimizer, ) - rng_key = jr.PRNGKey(23) post_samples, _ = estim.sample_posterior(rng_key, params, y_observed) print(post_samples) diff --git a/sbijax/_src/nn/consistency_model.py b/sbijax/_src/nn/consistency_model.py index 02259e3..a871d2b 100644 --- a/sbijax/_src/nn/consistency_model.py +++ b/sbijax/_src/nn/consistency_model.py @@ -4,16 +4,14 @@ import haiku as hk import jax from jax import numpy as jnp -from jax.nn import glu -from scipy import integrate -__all__ = ["ConsistencyModel", "make_consistency_model"] +from sbijax._src.nn.continuous_normalizing_flow import _ResnetBlock -from sbijax._src.nn.make_resnet import _Resnet +__all__ = ["ConsistencyModel", "make_consistency_model"] class ConsistencyModel(hk.Module): - """Conditional continuous normalizing flow. + """A consistency model. Args: n_dimension: the dimensionality of the modelled space @@ -21,10 +19,18 @@ class ConsistencyModel(hk.Module): take as input arguments named 'theta', 'time', 'context' and **kwargs. Theta, time and context are two-dimensional arrays with the same batch dimensions. + t_min: minimal time point for ODE integration + t_max: maximal time point for ODE integration """ - def __init__(self, n_dimension: int, transform: Callable, t_max=50): - """Conditional continuous normalizing flow. + def __init__( + self, + n_dimension: int, + transform: Callable, + t_min: float = 0.001, + t_max: float = 50.0, + ): + """Construct a consistency model. Args: n_dimension: the dimensionality of the modelled space @@ -32,11 +38,14 @@ def __init__(self, n_dimension: int, transform: Callable, t_max=50): take as input arguments named 'theta', 'time', 'context' and **kwargs. Theta, time and context are two-dimensional arrays with the same batch dimensions. 
+ t_min: minimal time point for ODE integration + t_max: maximal time point for ODE integration """ super().__init__() self._n_dimension = n_dimension self._network = transform self._t_max = t_max + self._t_min = t_min self._base_distribution = distrax.Normal(jnp.zeros(n_dimension), 1.0) def __call__(self, method, **kwargs): @@ -50,16 +59,26 @@ def __call__(self, method, **kwargs): """ return getattr(self, method)(**kwargs) - def sample(self, context): - """Sample from the pushforward. + def sample(self, context, **kwargs): + """Sample from the consistency model. Args: context: array of conditioning variables + kwargs: keyword argumente like 'is_training' """ - theta_0 = self._base_distribution.sample( + noise = self._base_distribution.sample( seed=hk.next_rng_key(), sample_shape=(context.shape[0],) ) - y_hat = self.vector_field(theta_0, self._t_max, context) + y_hat = self.vector_field(noise, self._t_max, context, **kwargs) + + noise = self._base_distribution.sample( + seed=hk.next_rng_key(), sample_shape=(y_hat.shape[0],) + ) + tme = self._t_min + (self._t_max - self._t_min) / 2 + noise = jnp.sqrt(jnp.square(tme) - jnp.square(self._t_min)) * noise + y_tme = y_hat + noise + y_hat = self.vector_field(y_tme, tme, context, **kwargs) + return y_hat def vector_field(self, theta, time, context, **kwargs): @@ -77,43 +96,6 @@ def vector_field(self, theta, time, context, **kwargs): return self._network(theta=theta, time=time, context=context, **kwargs) -# pylint: disable=too-many-arguments -class _ResnetBlock(hk.Module): - """A block for a 1d residual network.""" - - def __init__( - self, - hidden_size: int, - activation: Callable = jax.nn.relu, - dropout_rate: float = 0.2, - do_batch_norm: bool = False, - batch_norm_decay: float = 0.1, - ): - super().__init__() - self.hidden_size = hidden_size - self.activation = activation - self.do_batch_norm = do_batch_norm - self.dropout_rate = dropout_rate - self.batch_norm_decay = batch_norm_decay - - def __call__(self, inputs, context, is_training=False): - outputs = inputs - if self.do_batch_norm: - outputs = hk.BatchNorm(True, True, self.batch_norm_decay)( - outputs, is_training=is_training - ) - outputs = hk.Linear(self.hidden_size)(outputs) - outputs = self.activation(outputs) - if is_training: - outputs = hk.dropout( - rng=hk.next_rng_key(), rate=self.dropout_rate, x=outputs - ) - outputs = hk.Linear(self.hidden_size)(outputs) - context_proj = hk.Linear(inputs.shape[-1])(context) - outputs = glu(jnp.concatenate([outputs, context_proj], axis=-1)) - return outputs + inputs - - # pylint: disable=too-many-arguments class _CMResnet(hk.Module): """A simplified 1-d residual network.""" @@ -127,8 +109,8 @@ def __init__( dropout_rate: float = 0.0, do_batch_norm: bool = False, batch_norm_decay: float = 0.1, - eps: float = 0.001, - sigma_data:float = 1.0 + t_min: float = 0.001, + sigma_data: float = 1.0, ): super().__init__() self.n_layers = n_layers @@ -140,9 +122,9 @@ def __init__( self.batch_norm_decay = batch_norm_decay self.sigma_data = sigma_data self.var_data = self.sigma_data**2 - self.eps = eps + self.t_min = t_min - def __call__(self, theta, time, context, is_training=False, **kwargs): + def __call__(self, theta, time, context, is_training, **kwargs): outputs = context t_theta_embedding = jnp.concatenate( [ @@ -164,19 +146,17 @@ def __call__(self, theta, time, context, is_training=False, **kwargs): outputs = self.activation(outputs) outputs = hk.Linear(self.n_dimension)(outputs) - # TODO(simon): how is sigma_data chosen automatically? 
- # in the meantime set it to 1 and use batch norm before - #outputs = hk.BatchNorm(True, True, self.batch_norm_decay)(outputs, is_training=is_training) + # TODO(simon): dan we choose sigma automatically? out_skip = self._c_skip(time) * theta + self._c_out(time) * outputs return out_skip def _c_skip(self, time): - return self.var_data / ((time - self.eps) ** 2 + self.var_data) + return self.var_data / ((time - self.t_min) ** 2 + self.var_data) def _c_out(self, time): return ( self.sigma_data - * (time - self.eps) + * (time - self.t_min) / jnp.sqrt(self.var_data + time**2) ) @@ -189,14 +169,13 @@ def make_consistency_model( dropout_rate: float = 0.2, do_batch_norm: bool = False, batch_norm_decay: float = 0.2, - t_max: float=50, - epsilon=0.001, - sigma_data:float=1.0 + t_min: float = 0.001, + t_max: float = 50.0, + sigma_data: float = 1.0, ): - """Create a conditional continuous normalizing flow. + """Create a consistency model. - The CCNF uses a residual network as transformer which is created - automatically. + The consistency model uses a residual network as score network. Args: n_dimension: dimensionality of modelled space @@ -206,8 +185,12 @@ def make_consistency_model( dropout_rate: dropout rate to use in resnet blocks do_batch_norm: use batch normalization or not batch_norm_decay: decay rate of EMA in batch norm layer + t_min: minimal time point for ODE integration + t_max: maximal time point for ODE integration + sigma_data: the standard deviation of the data :) + Returns: - returns a conditional continuous normalizing flow + returns a consistency model """ @hk.transform @@ -220,10 +203,10 @@ def _cm(method, **kwargs): do_batch_norm=do_batch_norm, dropout_rate=dropout_rate, batch_norm_decay=batch_norm_decay, - eps=epsilon, + t_min=t_min, sigma_data=sigma_data, ) - cm = ConsistencyModel(n_dimension, nn, t_max=t_max) + cm = ConsistencyModel(n_dimension, nn, t_min=t_min, t_max=t_max) return cm(method, **kwargs) return _cm diff --git a/sbijax/_src/scmpe.py b/sbijax/_src/scmpe.py index 8fa485c..547fd8c 100644 --- a/sbijax/_src/scmpe.py +++ b/sbijax/_src/scmpe.py @@ -8,31 +8,57 @@ from jax import random as jr from tqdm import tqdm -from sbijax._src._sne_base import SNE +from sbijax._src.sfmpe import SFMPE from sbijax._src.util.early_stopping import EarlyStopping def _alpha_t(time): - return 1 / (_time_schedule(time + 1) - _time_schedule(time)) + return 1.0 / (_time_schedule(time + 1) - _time_schedule(time)) -def _time_schedule(n, rho=7, eps=0.001, T_max=50, N=1000): - left = eps ** (1 / rho) - right = T_max ** (1 / rho) - eps ** (1 / rho) - right = (n - 1) / (N - 1) * right +def _time_schedule(n, rho=7, t_min=0.001, t_max=50, n_inters=1000): + left = t_min ** (1 / rho) + right = t_max ** (1 / rho) - t_min ** (1 / rho) + right = (n - 1) / (n_inters - 1) * right return (left + right) ** rho +def _discretization_schedule(n_iter, max_iter=1000): + s0, s1 = 10, 50 + nk = ( + (n_iter / max_iter) * (jnp.square(s1 + 1) - jnp.square(s0)) + + jnp.square(s0) + - 1 + ) + nk = jnp.ceil(jnp.sqrt(nk)) + 1 + return nk + + # pylint: disable=too-many-locals def _consistency_loss( - params, ema_params, rng_key, apply_fn, is_training=False, **batch + params, + ema_params, + rng_key, + apply_fn, + n_iter, + t_min, + t_max, + is_training=False, + **batch, ): theta = batch["theta"] + nk = _discretization_schedule(n_iter) t_key, rng_key = jr.split(rng_key) - time_idx = jr.randint(t_key, shape=(theta.shape[0],), minval=1, maxval=1000 - 1) - tn = _time_schedule(time_idx).reshape(-1, 1) - tnp1 = 
_time_schedule(time_idx + 1).reshape(-1, 1) + time_idx = jr.randint( + t_key, shape=(theta.shape[0],), minval=1, maxval=nk - 1 + ) + tn = _time_schedule( + time_idx, t_min=t_min, t_max=t_max, n_inters=nk + ).reshape(-1, 1) + tnp1 = _time_schedule( + time_idx + 1, t_min=t_min, t_max=t_max, n_inters=nk + ).reshape(-1, 1) noise_key, rng_key = jr.split(rng_key) noise = jr.normal(noise_key, shape=(*theta.shape,)) @@ -56,20 +82,20 @@ def _consistency_loss( context=batch["y"], is_training=is_training, ) - mse = jnp.mean(jnp.square(fnp1 - fn), axis=1) + mse = jnp.sqrt(jnp.mean(jnp.square(fnp1 - fn), axis=1)) loss = _alpha_t(time_idx) * mse return jnp.mean(loss) # pylint: disable=too-many-arguments,unused-argument,useless-parent-delegation -class SCMPE(SNE): +class SCMPE(SFMPE): r"""Sequential consistency model posterior estimation. Implements a sequential version of the CMPE algorithm introduced in [1]_. For all rounds $r > 1$ parameter samples :math:`\theta \sim \hat{p}^r(\theta)` are drawn from the approximate posterior instead of the prior when computing consistency - loss + loss. Note that the implementation does not strictly follow the paper. Args: model_fns: a tuple of tuples. The first element is a tuple that @@ -77,6 +103,8 @@ class SCMPE(SNE): log-probability of a data point. The second element is a simulator function. network: a neural network + t_min: minimal time point for ODE integration + t_max: maximal time point for ODE integration Examples: >>> import distrax @@ -91,13 +119,13 @@ class SCMPE(SNE): >>> estim = SCMPE(fns, net) References: - .. [1] Wildberger, Jonas, et al. "Flow Matching for Scalable - Simulation-Based Inference." Advances in Neural Information - Processing Systems, 2024. + .. [1] Schmitt, Marvin, et al. "Consistency Models for Scalable and + Fast Simulation-Based Inference". + arXiv preprint arXiv:2312.05440, 2023. """ - def __init__(self, model_fns, network): - """Construct a FMPE object. + def __init__(self, model_fns, network, t_max=50.0, t_min=0.001): + """Construct a SCMPE object. Args: model_fns: a tuple of tuples. The first element is a tuple that @@ -105,53 +133,12 @@ def __init__(self, model_fns, network): log-probability of a data point. The second element is a simulator function. network: network: a neural network + t_min: minimal time point for ODE integration + t_max: maximal time point for ODE integration """ super().__init__(model_fns, network) - - # pylint: disable=arguments-differ,too-many-locals - def fit( - self, - rng_key, - data, - *, - optimizer=optax.adam(0.0003), - n_iter=1000, - batch_size=100, - percentage_data_as_validation_set=0.1, - n_early_stopping_patience=10, - **kwargs, - ): - """Fit the model. 
- - Args: - rng_key: a jax random key - data: data set obtained from calling - `simulate_data_and_possibly_append` - optimizer: an optax optimizer object - n_iter: maximal number of training iterations per round - batch_size: batch size used for training the model - percentage_data_as_validation_set: percentage of the simulated - data that is used for validation and early stopping - n_early_stopping_patience: number of iterations of no improvement - of training the flow before stopping optimisation\ - - Returns: - a tuple of parameters and a tuple of the training information - """ - itr_key, rng_key = jr.split(rng_key) - train_iter, val_iter = self.as_iterators( - itr_key, data, batch_size, percentage_data_as_validation_set - ) - params, losses = self._fit_model_single_round( - seed=rng_key, - train_iter=train_iter, - val_iter=val_iter, - optimizer=optimizer, - n_iter=n_iter, - n_early_stopping_patience=n_early_stopping_patience, - ) - - return params, losses + self._t_min = t_min + self._t_max = t_max # pylint: disable=undefined-loop-variable def _fit_model_single_round( @@ -170,18 +157,22 @@ def _fit_model_single_round( loss_fn = jax.jit( partial( - _consistency_loss, apply_fn=self.model.apply, is_training=False + _consistency_loss, + apply_fn=self.model.apply, + is_training=True, + t_max=self._t_max, + t_min=self._t_min, ) ) @jax.jit def ema_update(params, avg_params): - return optax.incremental_update(avg_params, params , step_size=0.01) + return optax.incremental_update(avg_params, params, step_size=0.01) @jax.jit - def step(params, ema_params, rng, state, **batch): + def step(params, ema_params, rng, state, n_iter, **batch): loss, grads = jax.value_and_grad(loss_fn)( - params, ema_params, rng, **batch + params, ema_params, rng, n_iter=n_iter, **batch ) updates, new_state = optimizer.update(grads, state, params) new_params = optax.apply_updates(params, updates) @@ -189,7 +180,7 @@ def step(params, ema_params, rng, state, **batch): return loss, new_params, new_ema_params, new_state losses = np.zeros([n_iter, 2]) - early_stop = EarlyStopping(1e-3, n_early_stopping_patience*2) + early_stop = EarlyStopping(1e-3, n_early_stopping_patience * 2) best_params, best_loss = None, np.inf logging.info("training model") for i in tqdm(range(n_iter)): @@ -198,14 +189,14 @@ def step(params, ema_params, rng, state, **batch): for batch in train_iter: train_key, rng_key = jr.split(rng_key) batch_loss, params, ema_params, state = step( - params, ema_params, train_key, state, **batch + params, ema_params, train_key, state, n_iter + 1, **batch ) train_loss += batch_loss * ( batch["y"].shape[0] / train_iter.num_samples ) val_key, rng_key = jr.split(rng_key) validation_loss = self._validation_loss( - val_key, params, ema_params, val_iter + val_key, params, ema_params, n_iter, val_iter ) losses[i] = jnp.array([train_loss, validation_loss]) @@ -228,14 +219,19 @@ def _init_params(self, rng_key, **init_data): theta=init_data["theta"], time=times, context=init_data["y"], - is_training=False, + is_training=True, ) return params - def _validation_loss(self, rng_key, params, ema_params, val_iter): + def _validation_loss(self, rng_key, params, ema_params, n_iter, val_iter): loss_fn = jax.jit( partial( - _consistency_loss, apply_fn=self.model.apply, is_training=False + _consistency_loss, + apply_fn=self.model.apply, + is_training=False, + t_max=self._t_max, + t_min=self._t_min, + n_iter=n_iter, ) ) @@ -248,47 +244,3 @@ def body_fn(batch_key, **batch): val_key, rng_key = jr.split(rng_key) loss += body_fn(val_key, 
**batch) return loss - - def sample_posterior( - self, rng_key, params, observable, *, n_samples=4_000, **kwargs - ): - r"""Sample from the approximate posterior. - - Args: - rng_key: a jax random key - params: a pytree of neural network parameters - observable: observation to condition on - n_samples: number of samples to draw - - Returns: - returns an array of samples from the posterior distribution of - dimension (n_samples \times p) - """ - observable = jnp.atleast_2d(observable) - - thetas = None - n_curr = n_samples - n_total_simulations_round = 0 - while n_curr > 0: - n_sim = jnp.minimum(200, jnp.maximum(200, n_curr)) - n_total_simulations_round += n_sim - sample_key, rng_key = jr.split(rng_key) - proposal = self.model.apply( - params, - sample_key, - method="sample", - context=jnp.tile(observable, [n_sim, 1]), - ) - proposal_probs = self.prior_log_density_fn(proposal) - proposal_accepted = proposal[jnp.isfinite(proposal_probs)] - if thetas is None: - thetas = proposal_accepted - else: - thetas = jnp.vstack([thetas, proposal_accepted]) - n_curr -= proposal_accepted.shape[0] - - self.n_total_simulations += n_total_simulations_round - return ( - thetas[:n_samples], - thetas.shape[0] / n_total_simulations_round, - ) diff --git a/sbijax/_src/sfmpe.py b/sbijax/_src/sfmpe.py index d426c68..440434d 100644 --- a/sbijax/_src/sfmpe.py +++ b/sbijax/_src/sfmpe.py @@ -65,7 +65,8 @@ class SFMPE(SNE): For all rounds $r > 1$ parameter samples :math:`\theta \sim \hat{p}^r(\theta)` are drawn from the approximate posterior instead of the prior when computing the flow - matching loss. + matching loss. Note that the implementation does not strictly follow the + paper. Args: model_fns: a tuple of tuples. The first element is a tuple that @@ -93,7 +94,7 @@ class SFMPE(SNE): """ def __init__(self, model_fns, density_estimator): - """Construct a FMPE object. + """Construct a SFMPE object. Args: model_fns: a tuple of tuples. 
The first element is a tuple that @@ -261,6 +262,7 @@ def sample_posterior( sample_key, method="sample", context=jnp.tile(observable, [n_sim, 1]), + is_training=False, ) proposal_probs = self.prior_log_density_fn(proposal) proposal_accepted = proposal[jnp.isfinite(proposal_probs)] From 0323cb29fb62c6a881d6ceb722c1f0b8c361cd98 Mon Sep 17 00:00:00 2001 From: Simon Dirmeier Date: Thu, 29 Feb 2024 11:27:31 +0100 Subject: [PATCH 4/6] fix tests --- sbijax/_src/nn/consistency_model.py | 2 +- sbijax/_src/nn/continuous_normalizing_flow.py | 4 +- sbijax/_src/scmpe_test.py | 60 +++++++++++++++++++ sbijax/_src/sfmpe_test.py | 60 +++++++++++++++++++ 4 files changed, 123 insertions(+), 3 deletions(-) create mode 100644 sbijax/_src/scmpe_test.py create mode 100644 sbijax/_src/sfmpe_test.py diff --git a/sbijax/_src/nn/consistency_model.py b/sbijax/_src/nn/consistency_model.py index a871d2b..24fc43d 100644 --- a/sbijax/_src/nn/consistency_model.py +++ b/sbijax/_src/nn/consistency_model.py @@ -96,7 +96,7 @@ def vector_field(self, theta, time, context, **kwargs): return self._network(theta=theta, time=time, context=context, **kwargs) -# pylint: disable=too-many-arguments +# pylint: disable=too-many-arguments,too-many-instance-attributes class _CMResnet(hk.Module): """A simplified 1-d residual network.""" diff --git a/sbijax/_src/nn/continuous_normalizing_flow.py b/sbijax/_src/nn/continuous_normalizing_flow.py index 1c405e0..741d865 100644 --- a/sbijax/_src/nn/continuous_normalizing_flow.py +++ b/sbijax/_src/nn/continuous_normalizing_flow.py @@ -47,7 +47,7 @@ def __call__(self, method, **kwargs): """ return getattr(self, method)(**kwargs) - def sample(self, context): + def sample(self, context, **kwargs): """Sample from the pushforward. Args: @@ -61,7 +61,7 @@ def ode_func(time, theta_t): theta_t = theta_t.reshape(-1, self._n_dimension) time = jnp.full((theta_t.shape[0], 1), time) ret = self.vector_field( - theta=theta_t, time=time, context=context, is_training=False + theta=theta_t, time=time, context=context, **kwargs ) return ret.reshape(-1) diff --git a/sbijax/_src/scmpe_test.py b/sbijax/_src/scmpe_test.py new file mode 100644 index 0000000..8f6a38e --- /dev/null +++ b/sbijax/_src/scmpe_test.py @@ -0,0 +1,60 @@ +# pylint: skip-file + +import distrax +import haiku as hk +from jax import numpy as jnp + +from sbijax import SCMPE +from sbijax.nn import make_consistency_model + + +def prior_model_fns(): + p = distrax.Independent(distrax.Normal(jnp.zeros(2), jnp.ones(2)), 1) + return p.sample, p.log_prob + + +def simulator_fn(seed, theta): + p = distrax.MultivariateNormalDiag(theta, 0.1 * jnp.ones_like(theta)) + y = p.sample(seed=seed) + return y + + +def log_density_fn(theta, y): + prior = distrax.Uniform(jnp.full(2, -3.0), jnp.full(2, 3.0)) + likelihood = distrax.MultivariateNormalDiag( + theta, 0.1 * jnp.ones_like(theta) + ) + + lp = jnp.sum(prior.log_prob(theta)) + jnp.sum(likelihood.log_prob(y)) + return lp + + +def test_scmpe(): + rng_seq = hk.PRNGSequence(0) + y_observed = jnp.array([-1.0, 1.0]) + + prior_simulator_fn, prior_logdensity_fn = prior_model_fns() + fns = (prior_simulator_fn, prior_logdensity_fn), simulator_fn + + estim = SCMPE(fns, make_consistency_model(2)) + data, params = None, {} + for i in range(2): + data, _ = estim.simulate_data_and_possibly_append( + next(rng_seq), + params=params, + observable=y_observed, + data=data, + n_simulations=100, + n_chains=2, + n_samples=200, + n_warmup=100, + ) + params, info = estim.fit(next(rng_seq), data=data, n_iter=2) + _ = estim.sample_posterior( + 
next(rng_seq), + params, + y_observed, + n_chains=2, + n_samples=200, + n_warmup=100, + ) diff --git a/sbijax/_src/sfmpe_test.py b/sbijax/_src/sfmpe_test.py new file mode 100644 index 0000000..102ca07 --- /dev/null +++ b/sbijax/_src/sfmpe_test.py @@ -0,0 +1,60 @@ +# pylint: skip-file + +import distrax +import haiku as hk +from jax import numpy as jnp + +from sbijax import SFMPE +from sbijax.nn import make_ccnf + + +def prior_model_fns(): + p = distrax.Independent(distrax.Normal(jnp.zeros(2), jnp.ones(2)), 1) + return p.sample, p.log_prob + + +def simulator_fn(seed, theta): + p = distrax.MultivariateNormalDiag(theta, 0.1 * jnp.ones_like(theta)) + y = p.sample(seed=seed) + return y + + +def log_density_fn(theta, y): + prior = distrax.Uniform(jnp.full(2, -3.0), jnp.full(2, 3.0)) + likelihood = distrax.MultivariateNormalDiag( + theta, 0.1 * jnp.ones_like(theta) + ) + + lp = jnp.sum(prior.log_prob(theta)) + jnp.sum(likelihood.log_prob(y)) + return lp + + +def test_sfmpe(): + rng_seq = hk.PRNGSequence(0) + y_observed = jnp.array([-1.0, 1.0]) + + prior_simulator_fn, prior_logdensity_fn = prior_model_fns() + fns = (prior_simulator_fn, prior_logdensity_fn), simulator_fn + + estim = SFMPE(fns, make_ccnf(2)) + data, params = None, {} + for i in range(2): + data, _ = estim.simulate_data_and_possibly_append( + next(rng_seq), + params=params, + observable=y_observed, + data=data, + n_simulations=100, + n_chains=2, + n_samples=200, + n_warmup=100, + ) + params, info = estim.fit(next(rng_seq), data=data, n_iter=2) + _ = estim.sample_posterior( + next(rng_seq), + params, + y_observed, + n_chains=2, + n_samples=200, + n_warmup=100, + ) From 16b9b3b16823b05db1282111c7a05787e0478ffe Mon Sep 17 00:00:00 2001 From: Simon Dirmeier Date: Thu, 29 Feb 2024 11:29:12 +0100 Subject: [PATCH 5/6] fix tests --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index f14e859..0926110 100644 --- a/README.md +++ b/README.md @@ -19,7 +19,7 @@ - [Contrastive Neural Ratio Estimation](https://arxiv.org/abs/2210.06170) (short `SNR`) - [Neural Approximate Sufficient Statistics](https://arxiv.org/abs/2010.10079) (`SNASS`) - [Neural Approximate Slice Sufficient Statistics](https://openreview.net/forum?id=jjzJ768iV1) (`SNASSS`) -- [Flow matching posterior estimation](https://openreview.net/forum?id=jjzJ768iV1) (`SFMPE`) +- [Flow matching posterior estimation](https://arxiv.org/abs/2305.17161) (`SFMPE`) - [Consistency model posterior estimation](https://arxiv.org/abs/2312.05440) (`SCMPE`) where the acronyms in parentheses denote the names of the methods in `sbijax`. From 29a2a9bc2279e8517e0f434450de7b23173ea17f Mon Sep 17 00:00:00 2001 From: Simon Dirmeier Date: Thu, 29 Feb 2024 11:35:04 +0100 Subject: [PATCH 6/6] fix lints :) --- sbijax/_src/scmpe.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sbijax/_src/scmpe.py b/sbijax/_src/scmpe.py index 547fd8c..f9b2d83 100644 --- a/sbijax/_src/scmpe.py +++ b/sbijax/_src/scmpe.py @@ -34,7 +34,7 @@ def _discretization_schedule(n_iter, max_iter=1000): return nk -# pylint: disable=too-many-locals +# pylint: disable=too-many-locals,too-many-arguments def _consistency_loss( params, ema_params, @@ -223,6 +223,7 @@ def _init_params(self, rng_key, **init_data): ) return params + # pylint: disable=arguments-differ def _validation_loss(self, rng_key, params, ema_params, n_iter, val_iter): loss_fn = jax.jit( partial(
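For readers tracing the training code added in sbijax/_src/scmpe.py above, the following stand-alone sketch illustrates the two ingredients the `_consistency_loss` hunks revolve around: the rho-spaced time schedule between t_min and t_max, and the paired noise levels (t_n, t_{n+1}) at which the student network and the EMA teacher are evaluated on the same noise draw. It uses plain JAX only; the `demo_*` names are illustrative and not part of the package.

from jax import numpy as jnp
from jax import random as jr


def demo_time_schedule(n, rho=7.0, t_min=0.001, t_max=50.0, n_inters=1000):
    # rho-spaced discretization: t_1 = t_min, t_{n_inters} = t_max, with
    # steps packed densely near t_min (rho = 7, as in the patch).
    left = t_min ** (1 / rho)
    right = t_max ** (1 / rho) - t_min ** (1 / rho)
    return (left + (n - 1) / (n_inters - 1) * right) ** rho


def demo_consistency_pair(rng_key, theta, n_inters=1000):
    # Sample a step index per row and perturb theta with the *same* Gaussian
    # noise at two adjacent noise levels; the consistency loss then pulls the
    # student output at t_{n+1} towards the EMA teacher output at t_n.
    idx_key, noise_key = jr.split(rng_key)
    idx = jr.randint(idx_key, (theta.shape[0],), minval=1, maxval=n_inters - 1)
    tn = demo_time_schedule(idx, n_inters=n_inters).reshape(-1, 1)
    tnp1 = demo_time_schedule(idx + 1, n_inters=n_inters).reshape(-1, 1)
    noise = jr.normal(noise_key, theta.shape)
    return theta + tn * noise, tn, theta + tnp1 * noise, tnp1


theta = jnp.zeros((4, 2))
theta_tn, tn, theta_tnp1, tnp1 = demo_consistency_pair(jr.PRNGKey(0), theta)
print(jnp.hstack([tn, tnp1]))

In the patch itself, the perturbed pair is passed through `apply_fn` with the current parameters and the EMA parameters respectively, the per-sample difference is weighted by `_alpha_t`, i.e. 1 / (t_{n+1} - t_n), and PATCH 3 additionally grows the number of discretization steps with the training iteration via `_discretization_schedule`.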