Commit 7f2199f

add fome
dirmeier committed Feb 28, 2024
1 parent 483ba92 commit 7f2199f
Showing 5 changed files with 188 additions and 51 deletions.
77 changes: 77 additions & 0 deletions examples/bivariate_gaussian_fmpe.py
@@ -0,0 +1,77 @@
"""
Example using flow matching posterior estimation on a bivariate Gaussian
"""

import distrax
import haiku as hk
import matplotlib.pyplot as plt
import optax
import seaborn as sns
from jax import numpy as jnp
from jax import random as jr

from sbijax import FMPE
from sbijax.nn import CCNF


def prior_model_fns():
p = distrax.Independent(distrax.Normal(jnp.zeros(2), jnp.ones(2)), 1)
return p.sample, p.log_prob


def simulator_fn(seed, theta):
p = distrax.Normal(jnp.zeros_like(theta), 1.0)
y = theta + p.sample(seed=seed)
return y


def make_model(dim):
@hk.transform
def _mlp(method, **kwargs):
def _nn(theta, time, context, **kwargs):
ins = jnp.concatenate([theta, time, context], axis=-1)
outs = hk.nets.MLP([64, 64, dim])(ins)
return outs

ccnf = CCNF(dim, _nn)
return ccnf(method, **kwargs)

return _mlp


def run():
y_observed = jnp.array([2.0, -2.0])

prior_simulator_fn, prior_logdensity_fn = prior_model_fns()
fns = (prior_simulator_fn, prior_logdensity_fn), simulator_fn

estim = FMPE(fns, make_model(2))
optimizer = optax.adam(1e-3)

data, params = None, {}
for i in range(2):
data, _ = estim.simulate_data_and_possibly_append(
jr.fold_in(jr.PRNGKey(1), i),
params=params,
observable=y_observed,
data=data,
)
params, info = estim.fit(
jr.fold_in(jr.PRNGKey(2), i),
data=data,
optimizer=optimizer,
)

rng_key = jr.PRNGKey(23)
post_samples, _ = estim.sample_posterior(rng_key, params, y_observed)
fig, axes = plt.subplots(2)
for i, ax in enumerate(axes):
sns.histplot(post_samples[:, i], color="darkblue", ax=ax)
ax.set_xlim([-3.0, 3.0])
sns.despine()
plt.tight_layout()
plt.show()


if __name__ == "__main__":
run()
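A note on the example above: for this toy model the posterior is available in closed form. With a standard normal prior and a unit-variance Gaussian simulator, the posterior is N(y_observed / 2, I / 2). The sketch below is a hypothetical check one could append to run(); the function name and tolerance are illustrative and not part of the commit.

# Hypothetical sanity check, not part of this commit: with a N(0, I) prior and a
# unit-variance Gaussian simulator, the analytic posterior is N(y_obs / 2, I / 2),
# so the mean of the FMPE posterior draws should land near y_obs / 2.
def check_posterior_mean(post_samples, y_observed, atol=0.2):
    analytic_mean = y_observed / 2.0
    sample_mean = jnp.mean(post_samples, axis=0)
    assert jnp.allclose(sample_mean, analytic_mean, atol=atol)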
1 change: 1 addition & 0 deletions sbijax/__init__.py
@@ -6,6 +6,7 @@


from sbijax._src.abc.smc_abc import SMCABC
from sbijax._src.fmpe import FMPE
from sbijax._src.snass import SNASS
from sbijax._src.snasss import SNASSS
from sbijax._src.snl import SNL
61 changes: 38 additions & 23 deletions sbijax/_src/fmpe.py
@@ -6,10 +6,10 @@
from absl import logging
from jax import numpy as jnp
from jax import random as jr
from jax import scipy as jsp
from tqdm import tqdm

from sbijax._src._sne_base import SNE
from sbijax._src.nn.continuous_normalizing_flow import CCNF
from sbijax._src.util.early_stopping import EarlyStopping


@@ -29,7 +29,9 @@ def _ut(theta_t, theta, times, sigma_min):
return num / denom


def _cfm_loss(params, rng_key, apply_fn, sigma_min=0.001, **batch):
def _cfm_loss(
params, rng_key, apply_fn, sigma_min=0.001, is_training=True, **batch
):
theta = batch["theta"]
n, p = theta.shape

@@ -39,14 +41,23 @@ def _cfm_loss(params, rng_key, apply_fn, sigma_min=0.001, **batch):
theta_key, rng_key = jr.split(rng_key)
theta_t = _sample_theta_t(theta_key, times, theta, sigma_min)

vs = apply_fn(params, theta=theta_t, time=times, context=batch["y"])
train_rng, rng_key = jr.split(rng_key)
vs = apply_fn(
params,
train_rng,
method="vector_field",
theta=theta_t,
time=times,
context=batch["y"],
is_training=is_training,
)
uts = _ut(theta_t, theta, times, sigma_min)

loss = jnp.mean(jnp.square(vs - uts))
return loss
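The helpers _sample_theta_t and _ut are collapsed in this view. Assuming they follow the optimal-transport conditional probability path from the flow-matching literature cited below (the num / denom structure of _ut is consistent with it), a minimal sketch of that target would look as follows; treat the exact parameterization as an assumption rather than a transcription of the hidden lines.

def _sample_theta_t_sketch(rng_key, times, theta, sigma_min):
    # theta_t = (1 - (1 - sigma_min) * t) * eps + t * theta, with eps ~ N(0, I)
    eps = jr.normal(rng_key, shape=theta.shape)
    return (1.0 - (1.0 - sigma_min) * times) * eps + times * theta


def _ut_sketch(theta_t, theta, times, sigma_min):
    # target vector field of the optimal-transport conditional path
    num = theta - (1.0 - sigma_min) * theta_t
    denom = 1.0 - (1.0 - sigma_min) * times
    return num / denom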


# pylint: disable=too-many-arguments,unused-argument
# pylint: disable=too-many-arguments,unused-argument,useless-parent-delegation
class FMPE(SNE):
"""Flow matching posterior estimation.
@@ -55,9 +66,7 @@ class FMPE(SNE):
consists of functions to sample and evaluate the
log-probability of a data point. The second element is a
simulator function.
density_estimator: a (neural) conditional density estimator
to model the posterior distribution
num_atoms: number of atomic atoms
density_estimator: a continuous normalizing flow model
Examples:
>>> import distrax
@@ -72,12 +81,12 @@
>>> snr = SNP(fns, flow)
References:
.. [1] Greenberg, David, et al. "Automatic posterior transformation for
likelihood-free inference." International Conference on Machine
Learning, 2019.
.. [1] Wildberger, Jonas, et al. "Flow Matching for Scalable
Simulation-Based Inference." Advances in Neural Information
Processing Systems, 2024.
"""

def __init__(self, model_fns, density_estimator):
def __init__(self, model_fns, density_estimator: CCNF):
"""Construct a FMPE object.
Args:
@@ -98,12 +107,12 @@ def fit(
*,
optimizer=optax.adam(0.0003),
n_iter=1000,
batch_size=128,
batch_size=100,
percentage_data_as_validation_set=0.1,
n_early_stopping_patience=10,
**kwargs,
):
"""Fit an SNP model.
"""Fit the model.
Args:
rng_key: a jax random key
@@ -144,13 +153,14 @@ def _fit_model_single_round(
optimizer,
n_iter,
n_early_stopping_patience,
n_atoms,
):
init_key, seed = jr.split(seed)
params = self._init_params(init_key, **next(iter(train_iter)))
state = optimizer.init(params)

loss_fn = jax.jit(partial(_cfm_loss, apply_fn=self.model.apply))
loss_fn = jax.jit(
partial(_cfm_loss, apply_fn=self.model.apply, is_training=True)
)

@jax.jit
def step(params, rng, state, **batch):
@@ -175,9 +185,7 @@ def step(params, rng, state, **batch):
batch["y"].shape[0] / train_iter.num_samples
)
val_key, rng_key = jr.split(rng_key)
validation_loss = self._validation_loss(
val_key, params, val_iter, n_atoms
)
validation_loss = self._validation_loss(val_key, params, val_iter)
losses[i] = jnp.array([train_loss, validation_loss])

_, early_stop = early_stop.update(validation_loss)
@@ -192,15 +200,23 @@ def step(params, rng, state, **batch):
return best_params, losses

def _init_params(self, rng_key, **init_data):
times = jr.uniform(jr.PRNGKey(0), shape=(init_data["y"].shape[0], 1))
params = self.model.init(
rng_key, y=init_data["theta"], x=init_data["y"]
rng_key,
method="vector_field",
theta=init_data["theta"],
time=times,
context=init_data["y"],
is_training=False,
)
return params

def _validation_loss(self, rng_key, params, val_iter):
loss_fn = jax.jit(partial(_cfm_loss, apply_fn=self.model.apply))
loss_fn = jax.jit(
partial(_cfm_loss, apply_fn=self.model.apply, is_training=False)
)

def body_fn(batch_key, batch):
def body_fn(batch_key, **batch):
loss = loss_fn(params, batch_key, **batch)
return loss * (batch["y"].shape[0] / val_iter.num_samples)

@@ -238,8 +254,7 @@ def sample_posterior(
params,
sample_key,
method="sample",
sample_shape=(n_sim,),
x=jnp.tile(observable, [n_sim, 1]),
context=jnp.tile(observable, [n_sim, 1]),
)
proposal_probs = self.prior_log_density_fn(proposal)
proposal_accepted = proposal[jnp.isfinite(proposal_probs)]
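The remainder of sample_posterior is collapsed in this view. The visible lines condition the flow on the observable, draw n_sim proposals, and keep only those with finite prior log-density, i.e. draws inside the prior's support. An illustrative rejection loop in that spirit might look like this; the loop and names are assumptions, not the file's hidden continuation.

def _rejection_sample_sketch(rng_key, sample_fn, prior_log_density_fn, n_samples):
    # repeatedly draw from the flow and keep draws inside the prior support
    accepted, n_accepted = [], 0
    while n_accepted < n_samples:
        rng_key, sample_key = jr.split(rng_key)
        proposal = sample_fn(sample_key)
        proposal = proposal[jnp.isfinite(prior_log_density_fn(proposal))]
        accepted.append(proposal)
        n_accepted += proposal.shape[0]
    return jnp.concatenate(accepted)[:n_samples]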