Skip to content

Commit

Permalink
Add SRNE and remove rejection ABC (#21)
Browse files Browse the repository at this point in the history
* Add SRNE
* Remove rejection ABC
  • Loading branch information
dirmeier authored Feb 24, 2024
1 parent f2f9722 commit c7a218c
Show file tree
Hide file tree
Showing 19 changed files with 858 additions and 143 deletions.
32 changes: 21 additions & 11 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,29 +1,34 @@
# sbijax

[![status](http://www.repostatus.org/badges/latest/concept.svg)](http://www.repostatus.org/#concept)
[![active](https://www.repostatus.org/badges/latest/active.svg)](https://www.repostatus.org/#active)
[![ci](https://github.com/dirmeier/sbijax/actions/workflows/ci.yaml/badge.svg)](https://github.com/dirmeier/sbijax/actions/workflows/ci.yaml)
[![version](https://img.shields.io/pypi/v/sbijax.svg?colorB=black&style=flat)](https://pypi.org/project/sbijax/)

> Simulation-based inference in JAX
## About

SbiJAX implements several algorithms for simulation-based inference using
[JAX](https://github.com/google/jax), [Haiku](https://github.com/deepmind/dm-haiku) and [BlackJAX](https://github.com/blackjax-devs/blackjax).
`sbijax` implements several algorithms for simulation-based inference in
[JAX](https://github.com/google/jax) using [Haiku](https://github.com/deepmind/dm-haiku),
[Distrax](https://github.com/deepmind/distrax) and [BlackJAX](https://github.com/blackjax-devs/blackjax). Specifically, `sbijax` implements

SbiJAX so far implements
- [Sequential Monte Carlo ABC](https://www.routledge.com/Handbook-of-Approximate-Bayesian-Computation/Sisson-Fan-Beaumont/p/book/9780367733728) (`SMCABC`),
- [Neural Likelihood Estimation](https://arxiv.org/abs/1805.07226) (`SNL`)
- [Surjective Neural Likelihood Estimation](https://arxiv.org/abs/2308.01054) (`SSNL`)
- [Neural Posterior Estimation C](https://arxiv.org/abs/1905.07488) (short `SNP`)
- [Contrastive Neural Ratio Estimation](https://arxiv.org/abs/2210.06170) (short `SNR`)
- [Neural Approximate Sufficient Statistics](https://arxiv.org/abs/2010.10079) (`SNASS`)
- [Neural Approximate Slice Sufficient Statistics](https://openreview.net/forum?id=jjzJ768iV1) (`SNASSS`)

- Rejection ABC (`RejectionABC`),
- Sequential Monte Carlo ABC (`SMCABC`),
- Sequential Neural Likelihood Estimation (`SNL`)
- Surjective Sequential Neural Likelihood Estimation (`SSNL`)
- Sequential Neural Posterior Estimation C (short `SNP`)
where the acronyms in parentheses denote the names of the methods in `sbijax`.

## Examples

You can find several self-contained examples on how to use the algorithms in `examples`.
You can find several self-contained examples on how to use the algorithms in [examples](https://github.com/dirmeier/sbijax/tree/main/examples).

## Usage
## Documentation

Documentation can be found [here](https://sbijax.readthedocs.io/en/latest/).

## Installation

Expand All @@ -42,6 +47,11 @@ To install the latest GitHub <RELEASE>, use:
pip install git+https://github.com/dirmeier/sbijax@<RELEASE>
```

## Acknowledgements

> 📝 The package draws significant inspiration from the excellent Pytorch-based [`sbi`](https://github.com/sbi-dev/sbi) package which is substantially more
feature-complete and user-friendly, and better documented.

## Author

Simon Dirmeier <a href="mailto:sfyrbnd @ pm me">sfyrbnd @ pm me</a>
2 changes: 1 addition & 1 deletion examples/bivariate_gaussian_snl.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@
from jax import numpy as jnp
from jax import random as jr
from surjectors import (
MADE,
Chain,
MaskedAutoregressive,
Permutation,
TransformedDistribution,
)
from surjectors.nn import MADE
from surjectors.util import unstack

from sbijax import SNL
Expand Down
2 changes: 1 addition & 1 deletion examples/bivariate_gaussian_snp.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
Example using sequential posterior estimation on a bivariate Gaussian
Example using sequential neural posterior estimation on a bivariate Gaussian
"""

import distrax
Expand Down
72 changes: 72 additions & 0 deletions examples/bivariate_gaussian_snr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
"""
Example using sequential neural ratio estimation on a bivariate Gaussian
"""

import distrax
import haiku as hk
import matplotlib.pyplot as plt
import optax
import seaborn as sns
from jax import numpy as jnp
from jax import random as jr

from sbijax import SNR


def prior_model_fns():
p = distrax.Independent(distrax.Normal(jnp.zeros(2), jnp.ones(2)), 1)
return p.sample, p.log_prob


def simulator_fn(seed, theta):
p = distrax.Normal(jnp.zeros_like(theta), 1.0)
y = theta + p.sample(seed=seed)
return y


def make_model():
@hk.without_apply_rng
@hk.transform
def _mlp(inputs, **kwargs):
return hk.nets.MLP([64, 64, 1])(inputs)

return _mlp


def run():
y_observed = jnp.array([2.0, -2.0])

prior_simulator_fn, prior_logdensity_fn = prior_model_fns()
fns = (prior_simulator_fn, prior_logdensity_fn), simulator_fn

snr = SNR(fns, make_model())
optimizer = optax.adam(1e-3)

data, params = None, {}
for i in range(5):
data, _ = snr.simulate_data_and_possibly_append(
jr.fold_in(jr.PRNGKey(1), i),
params=params,
observable=y_observed,
data=data,
)
params, info = snr.fit(
jr.fold_in(jr.PRNGKey(2), i),
data=data,
optimizer=optimizer,
batch_size=100,
)

rng_key = jr.PRNGKey(23)
snr_samples, _ = snr.sample_posterior(rng_key, params, y_observed)
fig, axes = plt.subplots(2)
for i, ax in enumerate(axes):
sns.histplot(snr_samples[:, i], color="darkblue", ax=ax)
ax.set_xlim([-3.0, 3.0])
sns.despine()
plt.tight_layout()
plt.show()


if __name__ == "__main__":
run()
81 changes: 63 additions & 18 deletions examples/bivariate_gaussian_snasss.py → examples/slcp_snass.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""
Example using sequential neural approximate (slice) summary statistics on a
bivariate Gaussian with repeated dimensions
Example SNASS on the SLCP experiment
"""

import distrax
Expand All @@ -11,6 +10,7 @@
import seaborn as sns
from jax import numpy as jnp
from jax import random as jr
from jax import scipy as jsp
from surjectors import (
Chain,
MaskedAutoregressive,
Expand All @@ -23,20 +23,64 @@
from sbijax import SNASSS
from sbijax.nn import make_snasss_net

W = jr.normal(jr.PRNGKey(0), (2, 10))


def prior_model_fns():
p = distrax.Independent(distrax.Normal(jnp.zeros(2), jnp.ones(2)), 1)
p = distrax.Independent(
distrax.Uniform(jnp.full(5, -3.0), jnp.full(5, 3.0)), 1
)
return p.sample, p.log_prob


def simulator_fn(seed, theta):
y = theta @ W
y = y + distrax.Normal(jnp.zeros_like(y), 0.1).sample(seed=seed)
orig_shape = theta.shape
if theta.ndim == 2:
theta = theta[:, None, :]
us_key, noise_key = jr.split(seed)

def _unpack_params(ps):
m0 = ps[..., [0]]
m1 = ps[..., [1]]
s0 = ps[..., [2]] ** 2
s1 = ps[..., [3]] ** 2
r = jnp.tanh(ps[..., [4]])
return m0, m1, s0, s1, r

m0, m1, s0, s1, r = _unpack_params(theta)
us = distrax.Normal(0.0, 1.0).sample(
seed=us_key, sample_shape=(theta.shape[0], theta.shape[1], 4, 2)
)
xs = jnp.empty_like(us)
xs = xs.at[:, :, :, 0].set(s0 * us[:, :, :, 0] + m0)
y = xs.at[:, :, :, 1].set(
s1 * (r * us[:, :, :, 0] + jnp.sqrt(1.0 - r**2) * us[:, :, :, 1]) + m1
)
if len(orig_shape) == 2:
y = y.reshape((*theta.shape[:1], 8))
else:
y = y.reshape((*theta.shape[:2], 8))
return y


def likelihood_fn(theta, y):
mu = jnp.tile(theta[:2], 4)
s1, s2 = theta[2] ** 2, theta[3] ** 2
corr = s1 * s2 * jnp.tanh(theta[4])
cov = jnp.array([[s1**2, corr], [corr, s2**2]])
cov = jsp.linalg.block_diag(*[cov for _ in range(4)])
p = distrax.MultivariateNormalFullCovariance(mu, cov)
return p.log_prob(y)


def log_density_fn(theta, y):
prior_lp = distrax.Independent(
distrax.Uniform(jnp.full(5, -3.0), jnp.full(5, 3.0)), 1
).log_prob(theta)
likelihood_lp = likelihood_fn(theta, y)

lp = jnp.sum(prior_lp) + jnp.sum(likelihood_lp)
return lp


def make_model(dim):
def _bijector_fn(params):
means, log_scales = unstack(params, -1)
Expand All @@ -50,7 +94,7 @@ def _flow(method, **kwargs):
bijector_fn=_bijector_fn,
conditioner=MADE(
dim,
[50, 50, dim * 2],
[64, 64, dim * 2],
2,
w_init=hk.initializers.TruncatedNormal(0.001),
b_init=jnp.zeros,
Expand All @@ -75,38 +119,39 @@ def _flow(method, **kwargs):


def run():
y_observed = jnp.array([[2.0, -2.0]]) @ W
thetas = jnp.linspace(-2.0, 2.0, 5)
y_0 = simulator_fn(jr.PRNGKey(0), thetas.reshape(-1, 5)).reshape(-1, 8)

prior_simulator_fn, prior_logdensity_fn = prior_model_fns()
fns = (prior_simulator_fn, prior_logdensity_fn), simulator_fn

estim = SNASSS(
fns,
make_model(2),
make_snasss_net([64, 64, 2], [64, 64, 1], [64, 64, 1]),
make_model(5),
make_snasss_net([64, 64, 5], [64, 64, 1], [64, 64, 1]),
)
optimizer = optax.adam(1e-3)

data, params = None, {}
for i in range(2):
for i in range(5):
data, _ = estim.simulate_data_and_possibly_append(
jr.fold_in(jr.PRNGKey(1), i),
jr.fold_in(jr.PRNGKey(12), i),
params=params,
observable=y_observed,
observable=y_0,
data=data,
)
params, _ = estim.fit(
jr.fold_in(jr.PRNGKey(2), i),
jr.fold_in(jr.PRNGKey(23), i),
data=data,
optimizer=optimizer,
batch_size=100,
)

rng_key = jr.PRNGKey(23)
snp_samples, _ = estim.sample_posterior(rng_key, params, y_observed)
fig, axes = plt.subplots(2)
snasss_samples, _ = estim.sample_posterior(rng_key, params, y_0)
fig, axes = plt.subplots(5)
for i, ax in enumerate(axes):
sns.histplot(snp_samples[:, i], color="darkblue", ax=ax)
sns.histplot(snasss_samples[:, i], color="darkblue", ax=ax)
ax.set_xlim([-3.0, 3.0])
sns.despine()
plt.tight_layout()
Expand Down
15 changes: 5 additions & 10 deletions examples/slcp_ssnl.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,11 @@

import argparse
from functools import partial
from timeit import default_timer as timer

import distrax
import haiku as hk
import jax
import matplotlib.pyplot as plt
import numpy as np
import optax
import pandas as pd
import seaborn as sns
Expand All @@ -25,11 +23,11 @@
Permutation,
TransformedDistribution,
)
from surjectors.conditioners import MADE, mlp_conditioner
from surjectors.nn import MADE, make_mlp
from surjectors.util import unstack

from sbijax import SNL
from sbijax.mcmc.slice import sample_with_slice
from sbijax.mcmc import sample_with_slice


def prior_model_fns():
Expand All @@ -50,7 +48,7 @@ def _unpack_params(ps):
m1 = ps[..., [1]]
s0 = ps[..., [2]] ** 2
s1 = ps[..., [3]] ** 2
r = np.tanh(ps[..., [4]])
r = jnp.tanh(ps[..., [4]])
return m0, m1, s0, s1, r

m0, m1, s0, s1, r = _unpack_params(theta)
Expand All @@ -60,7 +58,7 @@ def _unpack_params(ps):
xs = jnp.empty_like(us)
xs = xs.at[:, :, :, 0].set(s0 * us[:, :, :, 0] + m0)
y = xs.at[:, :, :, 1].set(
s1 * (r * us[:, :, :, 0] + np.sqrt(1.0 - r**2) * us[:, :, :, 1]) + m1
s1 * (r * us[:, :, :, 0] + jnp.sqrt(1.0 - r**2) * us[:, :, :, 1]) + m1
)
if len(orig_shape) == 2:
y = y.reshape((*theta.shape[:1], 8))
Expand Down Expand Up @@ -95,7 +93,7 @@ def _bijector_fn(params):
return distrax.ScalarAffine(means, jnp.exp(log_scales))

def _decoder_fn(n_dim):
decoder_net = mlp_conditioner(
decoder_net = make_mlp(
[50, n_dim * 2],
w_init=hk.initializers.TruncatedNormal(stddev=0.001),
)
Expand Down Expand Up @@ -189,7 +187,6 @@ def run(use_surjectors):
optimizer = optax.adam(1e-3)

data, params = None, {}
start = timer()
for i in range(5):
data, _ = snl.simulate_data_and_possibly_append(
jr.fold_in(jr.PRNGKey(12), i),
Expand All @@ -201,8 +198,6 @@ def run(use_surjectors):
params, info = snl.fit(
jr.fold_in(jr.PRNGKey(23), i), data=data, optimizer=optimizer
)
end = timer()
print(end - start)

sample_key, rng_key = jr.split(jr.PRNGKey(123))
snl_samples, _ = snl.sample_posterior(sample_key, params, y_observed)
Expand Down
10 changes: 8 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ dependencies = [
"dm-haiku>=0.0.9",
"optax>=0.1.3",
"surjectors>=0.3.0",
"tfp-nightly>=0.20.0.dev20230404"
"tfp-nightly>=0.20.0.dev20230404",
"tqdm>=4.64.1"
]
dynamic = ["version"]

Expand All @@ -46,6 +47,12 @@ exclude = [
"/.pre-commit-config.yaml"
]

[tool.hatch.envs.examples]
dependencies = [
"matplotlib>=3.6.1",
"seaborn>=0.12.2"
]

[tool.hatch.envs.test]
dependencies = [
"pylint>=2.15.10",
Expand All @@ -57,7 +64,6 @@ dependencies = [
lint = 'pylint sbijax'
test = 'pytest -v --doctest-modules --cov=./sbijax --cov-report=xml sbijax'


[tool.black]
line-length = 80
target-version = ['py311']
Expand Down
Loading

0 comments on commit c7a218c

Please sign in to comment.