diff --git a/README.md b/README.md
index 7aac6bb..6ca94a5 100644
--- a/README.md
+++ b/README.md
@@ -53,6 +53,7 @@ pip install git+https://github.com/dirmeier/sbijax@

 ## Acknowledgements

+> [!NOTE]
 > 📝 The API of the package is heavily inspired by the excellent Pytorch-based [`sbi`](https://github.com/sbi-dev/sbi) package which is substantially more feature-complete and user-friendly, and better documented.
diff --git a/docs/sbijax.rst b/docs/sbijax.rst
index 332fa0c..5ca2633 100644
--- a/docs/sbijax.rst
+++ b/docs/sbijax.rst
@@ -16,6 +16,7 @@ Methods
     SNP
     SNR
     SFMPE
+    SCMPE
     SNASS
     SNASSS
@@ -46,6 +47,12 @@ SNR
 SFMPE
 ~~~~~

 .. autoclass:: SFMPE
     :members: fit, simulate_data_and_possibly_append, sample_posterior
+
+SCMPE
+~~~~~
+
+.. autoclass:: SCMPE
+    :members: fit, simulate_data_and_possibly_append, sample_posterior
diff --git a/examples/bivariate_gaussian_cfmpe.py b/examples/bivariate_gaussian_cfmpe.py
new file mode 100644
index 0000000..0c1570c
--- /dev/null
+++ b/examples/bivariate_gaussian_cfmpe.py
@@ -0,0 +1,85 @@
+"""
+Example using consistency model posterior estimation on a bivariate Gaussian
+"""
+
+import distrax
+import haiku as hk
+import matplotlib.pyplot as plt
+import optax
+import seaborn as sns
+from jax import numpy as jnp
+from jax import random as jr
+
+from sbijax import SCMPE
+from sbijax.nn import ConsistencyModel
+
+
+def prior_model_fns():
+    p = distrax.Independent(distrax.Normal(jnp.zeros(2), jnp.ones(2)), 1)
+    return p.sample, p.log_prob
+
+
+def simulator_fn(seed, theta):
+    p = distrax.Normal(jnp.zeros_like(theta), 1.0)
+    y = theta + p.sample(seed=seed)
+    return y
+
+
+def make_model(dim):
+    @hk.transform
+    def _mlp(method, **kwargs):
+        def _c_skip(time):
+            return 1 / ((time - 0.001) ** 2 + 1)
+
+        def _c_out(time):
+            return 1.0 * (time - 0.001) / jnp.sqrt(1 + time**2)
+
+        def _nn(theta, time, context, **kwargs):
+            ins = jnp.concatenate([theta, time, context], axis=-1)
+            outs = hk.nets.MLP([64, 64, dim])(ins)
+            out_skip = _c_skip(time) * theta + _c_out(time) * outs
+            return out_skip
+
+        cm = ConsistencyModel(dim, _nn)
+        return cm(method, **kwargs)
+
+    return _mlp
+
+
+def run():
+    y_observed = jnp.array([2.0, -2.0])
+
+    prior_simulator_fn, prior_logdensity_fn = prior_model_fns()
+    fns = (prior_simulator_fn, prior_logdensity_fn), simulator_fn
+
+    estim = SCMPE(fns, make_model(2))
+    optimizer = optax.adam(1e-3)
+
+    data, params = None, {}
+    for i in range(2):
+        data, _ = estim.simulate_data_and_possibly_append(
+            jr.fold_in(jr.PRNGKey(1), i),
+            params=params,
+            observable=y_observed,
+            data=data,
+        )
+        params, info = estim.fit(
+            jr.fold_in(jr.PRNGKey(2), i),
+            data=data,
+            optimizer=optimizer,
+        )
+
+    rng_key = jr.PRNGKey(23)
+    post_samples, _ = estim.sample_posterior(rng_key, params, y_observed)
+    print(post_samples)
+    fig, axes = plt.subplots(2)
+    for i, ax in enumerate(axes):
+        sns.histplot(post_samples[:, i], color="darkblue", ax=ax)
+        ax.set_xlim([-3.0, 3.0])
+        sns.despine()
+    plt.tight_layout()
+    plt.show()
+
+
+if __name__ == "__main__":
+    run()
diff --git a/sbijax/__init__.py b/sbijax/__init__.py
index 646b951..75ada41 100644
--- a/sbijax/__init__.py
+++ b/sbijax/__init__.py
@@ -2,11 +2,12 @@
 """
 sbijax: Simulation-based inference in JAX
 """

-__version__ = "0.1.9"
+__version__ = "0.2.0"

 from sbijax._src.abc.smc_abc import SMCABC
-from sbijax._src.fmpe import SFMPE
+from sbijax._src.scmpe import SCMPE
+from sbijax._src.sfmpe import SFMPE
 from sbijax._src.snass import SNASS
 from sbijax._src.snasss import SNASSS
 from sbijax._src.snl import SNL
diff --git a/sbijax/_src/nn/consistency_model.py b/sbijax/_src/nn/consistency_model.py
new file mode 100644
index 0000000..02259e3
--- /dev/null
+++ b/sbijax/_src/nn/consistency_model.py
@@ -0,0 +1,229 @@
+from typing import Callable
+
+import distrax
+import haiku as hk
+import jax
+from jax import numpy as jnp
+from jax.nn import glu
+from scipy import integrate
+
+__all__ = ["ConsistencyModel", "make_consistency_model"]
+
+from sbijax._src.nn.make_resnet import _Resnet
+
+
+class ConsistencyModel(hk.Module):
+    """A consistency model.
+
+    Args:
+        n_dimension: the dimensionality of the modelled space
+        transform: a haiku module. The transform is a callable that has to
+            take as input arguments named 'theta', 'time', 'context' and
+            **kwargs. Theta, time and context are two-dimensional arrays
+            with the same batch dimensions.
+    """
+
+    def __init__(self, n_dimension: int, transform: Callable, t_max=50):
+        """A consistency model.
+
+        Args:
+            n_dimension: the dimensionality of the modelled space
+            transform: a haiku module. The transform is a callable that has to
+                take as input arguments named 'theta', 'time', 'context' and
+                **kwargs. Theta, time and context are two-dimensional arrays
+                with the same batch dimensions.
+        """
+        super().__init__()
+        self._n_dimension = n_dimension
+        self._network = transform
+        self._t_max = t_max
+        self._base_distribution = distrax.Normal(jnp.zeros(n_dimension), 1.0)
+
+    def __call__(self, method, **kwargs):
+        """Apply the consistency model.
+
+        Args:
+            method (str): method to call
+
+        Keyword Args:
+            keyword arguments for the called method
+        """
+        return getattr(self, method)(**kwargs)
+
+    def sample(self, context):
+        """Sample from the pushforward.
+
+        Args:
+            context: array of conditioning variables
+        """
+        theta_0 = self._base_distribution.sample(
+            seed=hk.next_rng_key(), sample_shape=(context.shape[0],)
+        )
+        y_hat = self.vector_field(theta_0, self._t_max, context)
+        return y_hat
+
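+    # Note: despite its name (presumably kept for API symmetry with SFMPE),
+    # the method below evaluates the consistency function
+    # f(theta_t, t | context), not an ODE vector field.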
+    def vector_field(self, theta, time, context, **kwargs):
+        """Compute the vector field.
+
+        Args:
+            theta: array of parameters
+            time: time variables
+            context: array of conditioning variables
+
+        Keyword Args:
+            keyword arguments that are passed to the neural network
+        """
+        time = jnp.full((theta.shape[0], 1), time)
+        return self._network(theta=theta, time=time, context=context, **kwargs)
+
+
+# pylint: disable=too-many-arguments
+class _ResnetBlock(hk.Module):
+    """A block for a 1d residual network."""
+
+    def __init__(
+        self,
+        hidden_size: int,
+        activation: Callable = jax.nn.relu,
+        dropout_rate: float = 0.2,
+        do_batch_norm: bool = False,
+        batch_norm_decay: float = 0.1,
+    ):
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.activation = activation
+        self.do_batch_norm = do_batch_norm
+        self.dropout_rate = dropout_rate
+        self.batch_norm_decay = batch_norm_decay
+
+    def __call__(self, inputs, context, is_training=False):
+        outputs = inputs
+        if self.do_batch_norm:
+            outputs = hk.BatchNorm(True, True, self.batch_norm_decay)(
+                outputs, is_training=is_training
+            )
+        outputs = hk.Linear(self.hidden_size)(outputs)
+        outputs = self.activation(outputs)
+        if is_training:
+            outputs = hk.dropout(
+                rng=hk.next_rng_key(), rate=self.dropout_rate, x=outputs
+            )
+        outputs = hk.Linear(self.hidden_size)(outputs)
+        context_proj = hk.Linear(inputs.shape[-1])(context)
+        outputs = glu(jnp.concatenate([outputs, context_proj], axis=-1))
+        return outputs + inputs
+
+
+# pylint: disable=too-many-arguments
+class _CMResnet(hk.Module):
+    """A simplified 1-d residual network."""
+
+    def __init__(
+        self,
+        n_layers: int,
+        n_dimension: int,
+        hidden_size: int,
+        activation: Callable = jax.nn.relu,
+        dropout_rate: float = 0.0,
+        do_batch_norm: bool = False,
+        batch_norm_decay: float = 0.1,
+        eps: float = 0.001,
+        sigma_data: float = 1.0,
+    ):
+        super().__init__()
+        self.n_layers = n_layers
+        self.n_dimension = n_dimension
+        self.hidden_size = hidden_size
+        self.activation = activation
+        self.do_batch_norm = do_batch_norm
+        self.dropout_rate = dropout_rate
+        self.batch_norm_decay = batch_norm_decay
+        self.sigma_data = sigma_data
+        self.var_data = self.sigma_data**2
+        self.eps = eps
+
+    def __call__(self, theta, time, context, is_training=False, **kwargs):
+        outputs = context
+        t_theta_embedding = jnp.concatenate(
+            [
+                hk.Linear(self.n_dimension)(theta),
+                hk.Linear(self.n_dimension)(time),
+            ],
+            axis=-1,
+        )
+        outputs = hk.Linear(self.hidden_size)(outputs)
+        outputs = self.activation(outputs)
+        for _ in range(self.n_layers):
+            outputs = _ResnetBlock(
+                hidden_size=self.hidden_size,
+                activation=self.activation,
+                dropout_rate=self.dropout_rate,
+                do_batch_norm=self.do_batch_norm,
+                batch_norm_decay=self.batch_norm_decay,
+            )(outputs, context=t_theta_embedding, is_training=is_training)
+        outputs = self.activation(outputs)
+        outputs = hk.Linear(self.n_dimension)(outputs)
+
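+        # _c_skip and _c_out below implement the consistency-model boundary
+        # condition: at time == eps, c_skip == 1 and c_out == 0, so the
+        # parameterization reduces to the identity map on theta.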
+        # TODO(simon): how is sigma_data chosen automatically?
+        # in the meantime set it to 1 and use batch norm before
+        # outputs = hk.BatchNorm(True, True, self.batch_norm_decay)(outputs, is_training=is_training)
+        out_skip = self._c_skip(time) * theta + self._c_out(time) * outputs
+        return out_skip
+
+    def _c_skip(self, time):
+        return self.var_data / ((time - self.eps) ** 2 + self.var_data)
+
+    def _c_out(self, time):
+        return (
+            self.sigma_data
+            * (time - self.eps)
+            / jnp.sqrt(self.var_data + time**2)
+        )
+
+
+def make_consistency_model(
+    n_dimension: int,
+    n_layers: int = 2,
+    hidden_size: int = 64,
+    activation: Callable = jax.nn.tanh,
+    dropout_rate: float = 0.2,
+    do_batch_norm: bool = False,
+    batch_norm_decay: float = 0.2,
+    t_max: float = 50,
+    epsilon: float = 0.001,
+    sigma_data: float = 1.0,
+):
+    """Create a consistency model.
+
+    The consistency model uses a residual network as transform, which is
+    created automatically.
+
+    Args:
+        n_dimension: dimensionality of modelled space
+        n_layers: number of resnet blocks
+        hidden_size: sizes of hidden layers for each resnet block
+        activation: a jax activation function
+        dropout_rate: dropout rate to use in resnet blocks
+        do_batch_norm: use batch normalization or not
+        batch_norm_decay: decay rate of EMA in batch norm layer
+
+    Returns:
+        returns a consistency model
+    """
+
+    @hk.transform
+    def _cm(method, **kwargs):
+        nn = _CMResnet(
+            n_layers=n_layers,
+            n_dimension=n_dimension,
+            hidden_size=hidden_size,
+            activation=activation,
+            do_batch_norm=do_batch_norm,
+            dropout_rate=dropout_rate,
+            batch_norm_decay=batch_norm_decay,
+            eps=epsilon,
+            sigma_data=sigma_data,
+        )
+        cm = ConsistencyModel(n_dimension, nn, t_max=t_max)
+        return cm(method, **kwargs)
+
+    return _cm
diff --git a/sbijax/_src/scmpe.py b/sbijax/_src/scmpe.py
new file mode 100644
index 0000000..e5fed87
--- /dev/null
+++ b/sbijax/_src/scmpe.py
@@ -0,0 +1,294 @@
+from functools import partial
+
+import jax
+import numpy as np
+import optax
+from absl import logging
+from jax import numpy as jnp
+from jax import random as jr
+from tqdm import tqdm
+
+from sbijax._src._sne_base import SNE
+from sbijax._src.util.early_stopping import EarlyStopping
+
+
+def _alpha_t(time):
+    return 1 / (_time_schedule(time + 1) - _time_schedule(time))
+
+
+def _time_schedule(n, rho=7, eps=0.001, T_max=50, N=1000):
+    left = eps ** (1 / rho)
+    right = T_max ** (1 / rho) - eps ** (1 / rho)
+    right = (n - 1) / (N - 1) * right
+    return (left + right) ** rho
+
+
+# pylint: disable=too-many-locals
+def _consistency_loss(
+    params, ema_params, rng_key, apply_fn, is_training=False, **batch
+):
+    theta = batch["theta"]
+
+    t_key, rng_key = jr.split(rng_key)
+    time_idx = jr.randint(t_key, shape=(theta.shape[0],), minval=1, maxval=1000 - 1)
+    tn = _time_schedule(time_idx).reshape(-1, 1)
+    tnp1 = _time_schedule(time_idx + 1).reshape(-1, 1)
+
+    noise_key, rng_key = jr.split(rng_key)
+    noise = jr.normal(noise_key, shape=(*theta.shape,))
+
+    train_rng, rng_key = jr.split(rng_key)
+    fnp1 = apply_fn(
+        params,
+        train_rng,
+        method="vector_field",
+        theta=theta + tnp1 * noise,
+        time=tnp1,
+        context=batch["y"],
+        is_training=is_training,
+    )
+    fn = apply_fn(
+        ema_params,
+        train_rng,
+        method="vector_field",
+        theta=theta + tn * noise,
+        time=tn,
+        context=batch["y"],
+        is_training=is_training,
+    )
+    mse = jnp.mean(jnp.square(fnp1 - fn), axis=1)
+    loss = _alpha_t(time_idx) * mse
+    return jnp.mean(loss)
+
+
+# pylint: disable=too-many-arguments,unused-argument,useless-parent-delegation
+class SCMPE(SNE):
+    r"""Sequential consistency model posterior estimation.
+
+    Implements a sequential version of the CMPE algorithm introduced in [1]_.
+    For all rounds :math:`r > 1` parameter samples
+    :math:`\theta \sim \hat{p}^r(\theta)` are drawn from
+    the approximate posterior instead of the prior when computing the
+    consistency loss.
+
+    Args:
+        model_fns: a tuple of tuples. The first element is a tuple that
+            consists of functions to sample and evaluate the
+            log-probability of a data point. The second element is a
+            simulator function.
+        network: a neural network
+
+    Examples:
+        >>> import distrax
+        >>> from sbijax import SCMPE
+        >>> from sbijax.nn import make_consistency_model
+        >>>
+        >>> prior = distrax.Normal(0.0, 1.0)
+        >>> s = lambda seed, theta: distrax.Normal(theta, 1.0).sample(seed=seed)
+        >>> fns = (prior.sample, prior.log_prob), s
+        >>> net = make_consistency_model(1)
+        >>>
+        >>> estim = SCMPE(fns, net)
+
+    References:
+        .. [1] Schmitt, Marvin, et al. "Consistency Models for Scalable and
+           Fast Simulation-Based Inference." Advances in Neural Information
+           Processing Systems, 2024.
+    """
+
+    def __init__(self, model_fns, network):
+        """Construct an SCMPE object.
+
+        Args:
+            model_fns: a tuple of tuples. The first element is a tuple that
+                consists of functions to sample and evaluate the
+                log-probability of a data point. The second element is a
+                simulator function.
+            network: a neural network
+        """
+        super().__init__(model_fns, network)
+
+    # pylint: disable=arguments-differ,too-many-locals
+    def fit(
+        self,
+        rng_key,
+        data,
+        *,
+        optimizer=optax.adam(0.0003),
+        n_iter=1000,
+        batch_size=100,
+        percentage_data_as_validation_set=0.1,
+        n_early_stopping_patience=10,
+        **kwargs,
+    ):
+        """Fit the model.
+
+        Args:
+            rng_key: a jax random key
+            data: data set obtained from calling
+                `simulate_data_and_possibly_append`
+            optimizer: an optax optimizer object
+            n_iter: maximal number of training iterations per round
+            batch_size: batch size used for training the model
+            percentage_data_as_validation_set: percentage of the simulated
+                data that is used for validation and early stopping
+            n_early_stopping_patience: number of iterations of no improvement
+                during training before stopping optimisation
+
+        Returns:
+            a tuple of parameters and a tuple of the training information
+        """
+        itr_key, rng_key = jr.split(rng_key)
+        train_iter, val_iter = self.as_iterators(
+            itr_key, data, batch_size, percentage_data_as_validation_set
+        )
+        params, losses = self._fit_model_single_round(
+            seed=rng_key,
+            train_iter=train_iter,
+            val_iter=val_iter,
+            optimizer=optimizer,
+            n_iter=n_iter,
+            n_early_stopping_patience=n_early_stopping_patience,
+        )
+
+        return params, losses
+
+    # pylint: disable=undefined-loop-variable
+    def _fit_model_single_round(
+        self,
+        seed,
+        train_iter,
+        val_iter,
+        optimizer,
+        n_iter,
+        n_early_stopping_patience,
+    ):
+        init_key, seed = jr.split(seed)
+        params = self._init_params(init_key, **next(iter(train_iter)))
+        ema_params = params.copy()
+        state = optimizer.init(params)
+
+        loss_fn = jax.jit(
+            partial(
+                _consistency_loss, apply_fn=self.model.apply, is_training=False
+            )
+        )
+
+        @jax.jit
+        def ema_update(params, avg_params):
+            return optax.incremental_update(params, avg_params, step_size=0.001)
+
+        @jax.jit
+        def step(params, ema_params, rng, state, **batch):
+            loss, grads = jax.value_and_grad(loss_fn)(
+                params, ema_params, rng, **batch
+            )
+            updates, new_state = optimizer.update(grads, state, params)
+            new_params = optax.apply_updates(params, updates)
+            new_ema_params = ema_update(new_params, ema_params)
+            return loss, new_params, new_ema_params, new_state
+
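+        # Training loop: after every optimizer step the EMA target parameters
+        # are nudged towards the online parameters (step_size=0.001 above);
+        # early stopping monitors the validation loss.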
+        losses = np.zeros([n_iter, 2])
+        early_stop = EarlyStopping(1e-3, n_early_stopping_patience)
+        best_params, best_loss = None, np.inf
+        logging.info("training model")
+        for i in tqdm(range(n_iter)):
+            train_loss = 0.0
+            rng_key = jr.fold_in(seed, i)
+            for batch in train_iter:
+                train_key, rng_key = jr.split(rng_key)
+                batch_loss, params, ema_params, state = step(
+                    params, ema_params, train_key, state, **batch
+                )
+                train_loss += batch_loss * (
+                    batch["y"].shape[0] / train_iter.num_samples
+                )
+            val_key, rng_key = jr.split(rng_key)
+            validation_loss = self._validation_loss(
+                val_key, params, ema_params, val_iter
+            )
+            losses[i] = jnp.array([train_loss, validation_loss])
+
+            _, early_stop = early_stop.update(validation_loss)
+            if early_stop.should_stop:
+                logging.info("early stopping criterion found")
+                break
+            if validation_loss < best_loss:
+                best_loss = validation_loss
+                best_params = params.copy()
+
+        losses = jnp.vstack(losses)[: (i + 1), :]
+        return best_params, losses
+
+    def _init_params(self, rng_key, **init_data):
+        times = jr.uniform(jr.PRNGKey(0), shape=(init_data["y"].shape[0], 1))
+        params = self.model.init(
+            rng_key,
+            method="vector_field",
+            theta=init_data["theta"],
+            time=times,
+            context=init_data["y"],
+            is_training=False,
+        )
+        return params
+
+    def _validation_loss(self, rng_key, params, ema_params, val_iter):
+        loss_fn = jax.jit(
+            partial(
+                _consistency_loss, apply_fn=self.model.apply, is_training=False
+            )
+        )
+
+        def body_fn(batch_key, **batch):
+            loss = loss_fn(params, ema_params, batch_key, **batch)
+            return loss * (batch["y"].shape[0] / val_iter.num_samples)
+
+        loss = 0.0
+        for batch in val_iter:
+            val_key, rng_key = jr.split(rng_key)
+            loss += body_fn(val_key, **batch)
+        return loss
+
+    def sample_posterior(
+        self, rng_key, params, observable, *, n_samples=4_000, **kwargs
+    ):
+        r"""Sample from the approximate posterior.
+
+        Args:
+            rng_key: a jax random key
+            params: a pytree of neural network parameters
+            observable: observation to condition on
+            n_samples: number of samples to draw
+
+        Returns:
+            returns an array of samples from the posterior distribution of
+            dimension (n_samples \times p)
+        """
+        observable = jnp.atleast_2d(observable)
+
+        thetas = None
+        n_curr = n_samples
+        n_total_simulations_round = 0
+        while n_curr > 0:
+            n_sim = jnp.minimum(200, jnp.maximum(200, n_curr))
+            n_total_simulations_round += n_sim
+            sample_key, rng_key = jr.split(rng_key)
+            proposal = self.model.apply(
+                params,
+                sample_key,
+                method="sample",
+                context=jnp.tile(observable, [n_sim, 1]),
+            )
+            proposal_probs = self.prior_log_density_fn(proposal)
+            proposal_accepted = proposal[jnp.isfinite(proposal_probs)]
+            if thetas is None:
+                thetas = proposal_accepted
+            else:
+                thetas = jnp.vstack([thetas, proposal_accepted])
+            n_curr -= proposal_accepted.shape[0]
+
+        self.n_total_simulations += n_total_simulations_round
+        return (
+            thetas[:n_samples],
+            thetas.shape[0] / n_total_simulations_round,
+        )
diff --git a/sbijax/_src/fmpe.py b/sbijax/_src/sfmpe.py
similarity index 100%
rename from sbijax/_src/fmpe.py
rename to sbijax/_src/sfmpe.py
diff --git a/sbijax/nn/__init__.py b/sbijax/nn/__init__.py
index 635b6f9..1972b41 100644
--- a/sbijax/nn/__init__.py
+++ b/sbijax/nn/__init__.py
@@ -1,5 +1,9 @@
 """Neural network module."""

+from sbijax._src.nn.consistency_model import (
+    ConsistencyModel,
+    make_consistency_model,
+)
 from sbijax._src.nn.continuous_normalizing_flow import CCNF, make_ccnf
 from sbijax._src.nn.make_flows import (
     make_affine_maf,