diff --git a/examples/bivariate_gaussian_fmpe.py b/examples/bivariate_gaussian_fmpe.py
new file mode 100644
index 0000000..668b239
--- /dev/null
+++ b/examples/bivariate_gaussian_fmpe.py
@@ -0,0 +1,77 @@
+"""
+Example using flow matching posterior estimation on a bivariate Gaussian
+"""
+
+import distrax
+import haiku as hk
+import matplotlib.pyplot as plt
+import optax
+import seaborn as sns
+from jax import numpy as jnp
+from jax import random as jr
+
+from sbijax import FMPE
+from sbijax.nn import CCNF
+
+
+def prior_model_fns():
+    p = distrax.Independent(distrax.Normal(jnp.zeros(2), jnp.ones(2)), 1)
+    return p.sample, p.log_prob
+
+
+def simulator_fn(seed, theta):
+    p = distrax.Normal(jnp.zeros_like(theta), 1.0)
+    y = theta + p.sample(seed=seed)
+    return y
+
+
+def make_model(dim):
+    @hk.transform
+    def _mlp(method, **kwargs):
+        def _nn(theta, time, context, **kwargs):
+            ins = jnp.concatenate([theta, time, context], axis=-1)
+            outs = hk.nets.MLP([64, 64, dim])(ins)
+            return outs
+
+        ccnf = CCNF(dim, _nn)
+        return ccnf(method, **kwargs)
+
+    return _mlp
+
+
+def run():
+    y_observed = jnp.array([2.0, -2.0])
+
+    prior_simulator_fn, prior_logdensity_fn = prior_model_fns()
+    fns = (prior_simulator_fn, prior_logdensity_fn), simulator_fn
+
+    estim = FMPE(fns, make_model(2))
+    optimizer = optax.adam(1e-3)
+
+    data, params = None, {}
+    for i in range(2):
+        data, _ = estim.simulate_data_and_possibly_append(
+            jr.fold_in(jr.PRNGKey(1), i),
+            params=params,
+            observable=y_observed,
+            data=data,
+        )
+        params, info = estim.fit(
+            jr.fold_in(jr.PRNGKey(2), i),
+            data=data,
+            optimizer=optimizer,
+        )
+
+    rng_key = jr.PRNGKey(23)
+    post_samples, _ = estim.sample_posterior(rng_key, params, y_observed)
+    fig, axes = plt.subplots(2)
+    for i, ax in enumerate(axes):
+        sns.histplot(post_samples[:, i], color="darkblue", ax=ax)
+        ax.set_xlim([-3.0, 3.0])
+    sns.despine()
+    plt.tight_layout()
+    plt.show()
+
+
+if __name__ == "__main__":
+    run()
diff --git a/sbijax/__init__.py b/sbijax/__init__.py
index 2bf3b76..833dd2f 100644
--- a/sbijax/__init__.py
+++ b/sbijax/__init__.py
@@ -6,6 +6,7 @@
 from sbijax._src.abc.smc_abc import SMCABC
+from sbijax._src.fmpe import FMPE
 from sbijax._src.snass import SNASS
 from sbijax._src.snasss import SNASSS
 from sbijax._src.snl import SNL
diff --git a/sbijax/_src/fmpe.py b/sbijax/_src/fmpe.py
index ced7567..8e3c89c 100644
--- a/sbijax/_src/fmpe.py
+++ b/sbijax/_src/fmpe.py
@@ -6,10 +6,10 @@
 from absl import logging
 from jax import numpy as jnp
 from jax import random as jr
-from jax import scipy as jsp
 from tqdm import tqdm
 
 from sbijax._src._sne_base import SNE
+from sbijax._src.nn.continuous_normalizing_flow import CCNF
 from sbijax._src.util.early_stopping import EarlyStopping
@@ -29,7 +29,9 @@ def _ut(theta_t, theta, times, sigma_min):
     return num / denom
 
 
-def _cfm_loss(params, rng_key, apply_fn, sigma_min=0.001, **batch):
+def _cfm_loss(
+    params, rng_key, apply_fn, sigma_min=0.001, is_training=True, **batch
+):
     theta = batch["theta"]
     n, p = theta.shape
@@ -39,14 +41,23 @@
     theta_key, rng_key = jr.split(rng_key)
     theta_t = _sample_theta_t(theta_key, times, theta, sigma_min)
 
-    vs = apply_fn(params, theta=theta_t, time=times, context=batch["y"])
+    train_rng, rng_key = jr.split(rng_key)
+    vs = apply_fn(
+        params,
+        train_rng,
+        method="vector_field",
+        theta=theta_t,
+        time=times,
+        context=batch["y"],
+        is_training=is_training,
+    )
     uts = _ut(theta_t, theta, times, sigma_min)
     loss = jnp.mean(jnp.square(vs - uts))
     return loss
 
 
-# pylint: disable=too-many-arguments,unused-argument
+# pylint: disable=too-many-arguments,unused-argument,useless-parent-delegation
 class FMPE(SNE):
     """Flow matching posterior estimation.
@@ -55,9 +66,7 @@
             consists of functions to sample and evaluate the
             log-probability of a data point. The second element is a
            simulator function.
-        density_estimator: a (neural) conditional density estimator
-            to model the posterior distribution
-        num_atoms: number of atomic atoms
+        density_estimator: a continuous normalizing flow model
 
     Examples:
         >>> import distrax
@@ -72,12 +81,12 @@ class FMPE(SNE):
         >>> snr = SNP(fns, flow)
 
     References:
-        .. [1] Greenberg, David, et al. "Automatic posterior transformation for
-            likelihood-free inference." International Conference on Machine
-            Learning, 2019.
+        .. [1] Wildberger, Jonas, et al. "Flow Matching for Scalable
+            Simulation-Based Inference." Advances in Neural Information
+            Processing Systems, 2024.
     """
 
-    def __init__(self, model_fns, density_estimator):
+    def __init__(self, model_fns, density_estimator: CCNF):
         """Construct a FMPE object.
 
         Args:
@@ -98,12 +107,12 @@ def fit(
         *,
         optimizer=optax.adam(0.0003),
         n_iter=1000,
-        batch_size=128,
+        batch_size=100,
         percentage_data_as_validation_set=0.1,
         n_early_stopping_patience=10,
         **kwargs,
     ):
-        """Fit an SNP model.
+        """Fit the model.
 
         Args:
             rng_key: a jax random key
@@ -144,13 +153,14 @@ def _fit_model_single_round(
         optimizer,
         n_iter,
         n_early_stopping_patience,
-        n_atoms,
     ):
         init_key, seed = jr.split(seed)
         params = self._init_params(init_key, **next(iter(train_iter)))
         state = optimizer.init(params)
 
-        loss_fn = jax.jit(partial(_cfm_loss, apply_fn=self.model.apply))
+        loss_fn = jax.jit(
+            partial(_cfm_loss, apply_fn=self.model.apply, is_training=True)
+        )
 
         @jax.jit
         def step(params, rng, state, **batch):
@@ -175,9 +185,7 @@ def step(params, rng, state, **batch):
                    batch["y"].shape[0] / train_iter.num_samples
                )
             val_key, rng_key = jr.split(rng_key)
-            validation_loss = self._validation_loss(
-                val_key, params, val_iter, n_atoms
-            )
+            validation_loss = self._validation_loss(val_key, params, val_iter)
             losses[i] = jnp.array([train_loss, validation_loss])
 
             _, early_stop = early_stop.update(validation_loss)
@@ -192,15 +200,23 @@ def step(params, rng, state, **batch):
         return best_params, losses
 
     def _init_params(self, rng_key, **init_data):
+        times = jr.uniform(jr.PRNGKey(0), shape=(init_data["y"].shape[0], 1))
         params = self.model.init(
-            rng_key, y=init_data["theta"], x=init_data["y"]
+            rng_key,
+            method="vector_field",
+            theta=init_data["theta"],
+            time=times,
+            context=init_data["y"],
+            is_training=False,
         )
         return params
 
     def _validation_loss(self, rng_key, params, val_iter):
-        loss_fn = jax.jit(partial(_cfm_loss, apply_fn=self.model.apply))
+        loss_fn = jax.jit(
+            partial(_cfm_loss, apply_fn=self.model.apply, is_training=False)
+        )
 
-        def body_fn(batch_key, batch):
+        def body_fn(batch_key, **batch):
             loss = loss_fn(params, batch_key, **batch)
             return loss * (batch["y"].shape[0] / val_iter.num_samples)
@@ -238,8 +254,7 @@ def sample_posterior(
                 params,
                 sample_key,
                 method="sample",
-                sample_shape=(n_sim,),
-                x=jnp.tile(observable, [n_sim, 1]),
+                context=jnp.tile(observable, [n_sim, 1]),
             )
             proposal_probs = self.prior_log_density_fn(proposal)
             proposal_accepted = proposal[jnp.isfinite(proposal_probs)]
diff --git a/sbijax/_src/nn/continuous_normalizing_flow.py b/sbijax/_src/nn/continuous_normalizing_flow.py
index 2a55c4f..1c405e0 100644
--- a/sbijax/_src/nn/continuous_normalizing_flow.py
+++ b/sbijax/_src/nn/continuous_normalizing_flow.py
@@ -4,55 +4,92 @@
 import haiku as hk
 import jax
 from jax import numpy as jnp
-from jax.experimental.ode import odeint
 from jax.nn import glu
 from scipy import integrate
 
+__all__ = ["CCNF", "make_ccnf"]
+
 
 class CCNF(hk.Module):
-    def __init__(self, n_dimension, transform):
+    """Conditional continuous normalizing flow.
+
+    Args:
+        n_dimension: the dimensionality of the modelled space
+        transform: a haiku module. The transform is a callable that takes
+            as inputs arguments named 'theta', 'time', 'context' and
+            **kwargs; theta, time and context are two-dimensional arrays
+            with the same batch dimension.
+    """
+
+    def __init__(self, n_dimension: int, transform: Callable):
+        """Conditional continuous normalizing flow.
+
+        Args:
+            n_dimension: the dimensionality of the modelled space
+            transform: a haiku module. The transform is a callable that takes
+                as inputs arguments named 'theta', 'time', 'context' and
+                **kwargs; theta, time and context are two-dimensional arrays
+                with the same batch dimension.
+        """
         super().__init__()
         self._n_dimension = n_dimension
         self._network = transform
         self._base_distribution = distrax.Normal(jnp.zeros(n_dimension), 1.0)
 
     def __call__(self, method, **kwargs):
+        """Apply the flow.
+
+        Args:
+            method (str): method to call
+
+        Keyword Args:
+            keyword arguments for the called method
+        """
         return getattr(self, method)(**kwargs)
 
     def sample(self, context):
+        """Sample from the pushforward.
+
+        Args:
+            context: array of conditioning variables
+        """
         theta_0 = self._base_distribution.sample(
             seed=hk.next_rng_key(), sample_shape=(context.shape[0],)
         )
-        _, theta_1 = odeint(
-            lambda t, theta_t: self.vector_field(t, theta_t, context),
-            theta_0,
-            jnp.array([0.0, 1.0]),
-            atol=1e-7,
-            rtol=1e-7,
-        )
 
-        def ode_func(t, theta_t):
+        def ode_func(time, theta_t):
             theta_t = theta_t.reshape(-1, self._n_dimension)
-            times = jnp.full((theta_t.shape[0],), t)
+            time = jnp.full((theta_t.shape[0], 1), time)
             ret = self.vector_field(
-                theta_t=theta_t, times=times, context=context
+                theta=theta_t, time=time, context=context, is_training=False
             )
             return ret.reshape(-1)
 
         res = integrate.solve_ivp(
             ode_func,
-            (1, 0.00001),
+            (0.0, 1.0),
             theta_0.reshape(-1),
             rtol=1e-5,
             atol=1e-5,
             method="RK45",
         )
-        return res
+        ret = res.y[:, -1].reshape(-1, self._n_dimension)
+        return ret
 
-    def vector_field(self, theta_t, times, context):
-        times = jnp.full((theta_t.shape[0], 1), times)
-        return self._network(theta=theta_t, times=times, context=context)
+    def vector_field(self, theta, time, context, **kwargs):
+        """Compute the vector field.
+ + Args: + theta: array of parameters + time: time variables + context: array of conditioning variables + + Keyword Args: + keyword arguments that aer passed tothe neural network + """ + time = jnp.full((theta.shape[0], 1), time) + return self._network(theta=theta, time=time, context=context, **kwargs) # pylint: disable=too-many-arguments @@ -87,13 +124,13 @@ def __call__(self, inputs, context, is_training=False): rng=hk.next_rng_key(), rate=self.dropout_rate, x=outputs ) outputs = hk.Linear(self.hidden_size)(outputs) - context_proj = hk.Linear(inputs.dimension[-1])(context) + context_proj = hk.Linear(inputs.shape[-1])(context) outputs = glu(jnp.concatenate([outputs, context_proj], axis=-1)) return outputs + inputs # pylint: disable=too-many-arguments -class _CCNF_Resnet(hk.Module): +class _CCNFResnet(hk.Module): """A simplified 1-d residual network.""" def __init__( @@ -115,12 +152,16 @@ def __init__( self.dropout_rate = dropout_rate self.batch_norm_decay = batch_norm_decay - def __call__(self, theta, times, y, is_training=False, **kwargs): - outputs = y + def __call__(self, theta, time, context, is_training=False, **kwargs): + outputs = context + # this is a bit weird, but what the paper suggests: + # instead of using times and context (i.e., y) as conditioning variables + # it suggests using times and theta and use y in the resnet blocks, + # since theta is typically low-dim and y is typically high-dime t_theta_embedding = jnp.concatenate( [ hk.Linear(self.n_dimension)(theta), - hk.Linear(self.n_dimension)(times), + hk.Linear(self.n_dimension)(time), ], axis=-1, ) @@ -150,22 +191,24 @@ def make_ccnf( ): """Create a conditional continuous normalizing flow. + The CCNF uses a residual network as transformer which is created + automatically. + Args: - n_dimension: dimensionality of theta - n_layers: number of normalizing flow layers - hidden_size: sizes of hidden layers for each normalizing flow + n_dimension: dimensionality of modelled space + n_layers: number of resnet blocks + hidden_size: sizes of hidden layers for each resnet block activation: a jax activation function dropout_rate: dropout rate to use in resnet blocks do_batch_norm: use batch normalization or not batch_norm_decay: decay rate of EMA in batch norm layer Returns: - a neural network model + returns a conditional continuous normalizing flow """ - @hk.without_apply_rng @hk.transform def _flow(method, **kwargs): - nn = _CCNF_Resnet( + nn = _CCNFResnet( n_layers=n_layers, n_dimension=n_dimension, hidden_size=hidden_size, diff --git a/sbijax/nn/__init__.py b/sbijax/nn/__init__.py index c9082f7..635b6f9 100644 --- a/sbijax/nn/__init__.py +++ b/sbijax/nn/__init__.py @@ -1,5 +1,6 @@ """Neural network module.""" +from sbijax._src.nn.continuous_normalizing_flow import CCNF, make_ccnf from sbijax._src.nn.make_flows import ( make_affine_maf, make_surjective_affine_maf,