Deprecate distrax and update examples (#48)
dirmeier authored Aug 18, 2024
1 parent 335bd17 commit 6176367
Showing 19 changed files with 360 additions and 188 deletions.
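The substantive change across these files is swapping distrax distributions for their TensorFlow Probability (JAX substrate) counterparts. A minimal sketch of the migration pattern, in which the "before" side is an assumption (the deleted distrax code is not visible on this page) and the "after" side mirrors the prior used throughout the updated examples:

# Before (assumed distrax-based prior; the removed code is not shown in this diff):
# import distrax
# prior = distrax.Independent(distrax.Normal(jnp.zeros(2), jnp.ones(2)), 1)

# After, as written in the updated examples:
from jax import numpy as jnp
from tensorflow_probability.substrates.jax import distributions as tfd


def prior_fn():
    prior = tfd.JointDistributionNamed(dict(
        theta=tfd.Normal(jnp.zeros(2), 1)
    ), batch_ndims=0)
    return prior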
1 change: 1 addition & 0 deletions .gitattributes
@@ -0,0 +1 @@
*.ipynb linguist-vendored
47 changes: 47 additions & 0 deletions .github/workflows/examples.yaml
@@ -0,0 +1,47 @@
name: examples

on:
push:
branches: [ main ]
pull_request:
branches: [ main ]

jobs:
precommit:
name: Pre-commit checks
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3

examples:
runs-on: ubuntu-latest
needs:
- precommit
strategy:
matrix:
python-version: [3.11]
steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v3
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
pip install hatch matplotlib
- name: Build package
run: |
pip install jaxlib jax
pip install .
- name: Run tests
run: |
python examples/bivariate_gaussian-smcabc.py --n-rounds 1
python examples/mixture_model-cmpe.py --n-iter 10
python examples/mixture_model-nle.py --n-iter 10
python examples/mixture_model-nle.py --n-iter 10 --use-spf
python examples/mixture_model-npe.py --n-iter 10
python examples/mixture_model-nre.py --n-iter 10
python examples/slcp-fmpe.py --n-iter 10
python examples/slcp-nass_nle.py --n-iter 10 --n-rounds 1
python examples/slcp-nass_smcabc.py --n-iter 10 --n-rounds 1
python examples/slcp-snle.py --n-iter 10 --n-rounds 1
42 changes: 0 additions & 42 deletions examples/bivariate_gaussian-sabc.py

This file was deleted.

10 changes: 7 additions & 3 deletions examples/bivariate_gaussian-smcabc.py
@@ -2,6 +2,7 @@
Demonstrates sequential Monte Carlo ABC on a simple bivariate Gaussian example.
"""
+import argparse

import jax
import matplotlib.pyplot as plt
@@ -35,18 +36,21 @@ def distance_fn(y_simulated, y_observed):
return dist


-def run():
+def run(n_rounds):
y_observed = jnp.array([-2.0, 1.0])

fns = prior_fn, simulator_fn

smc = SMCABC(fns, summary_fn, distance_fn)
smc_samples, _ = smc.sample_posterior(
-        jr.PRNGKey(1), y_observed, 10, 1000, 0.85, 500
+        jr.PRNGKey(1), y_observed, n_rounds=n_rounds, n_particles=1000, ess_min=500
)
plot_posterior(smc_samples)
plt.show()


if __name__ == "__main__":
-    run()
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--n-rounds", type=int, default=10)
+    args = parser.parse_args()
+    run(args.n_rounds)
50 changes: 50 additions & 0 deletions examples/mixture_model-cmpe.py
@@ -0,0 +1,50 @@
"""Consistency model posterior estimation example.
Demonstrates CMPE on a simple mixture model.
"""
import argparse

import matplotlib.pyplot as plt
from jax import numpy as jnp, random as jr
from tensorflow_probability.substrates.jax import distributions as tfd

from sbijax import plot_posterior, CMPE
from sbijax.nn import make_cm


def prior_fn():
prior = tfd.JointDistributionNamed(dict(
theta=tfd.Normal(jnp.zeros(2), 1)
), batch_ndims=0)
return prior


def simulator_fn(seed, theta):
mean = theta["theta"].reshape(-1, 2)
n = mean.shape[0]
data_key, cat_key = jr.split(seed)
categories = tfd.Categorical(logits=jnp.zeros(2)).sample(seed=cat_key, sample_shape=(n,))
scales = jnp.array([1.0, 0.1])[categories].reshape(-1, 1)
y = tfd.Normal(mean, scales).sample(seed=data_key)
return y


def run(n_iter):
y_observed = jnp.array([-2.0, 1.0])
fns = prior_fn, simulator_fn
neural_network = make_cm(2, 64)
model = CMPE(fns, neural_network)

data, _ = model.simulate_data(jr.PRNGKey(1), n_simulations=10_000)
params, info = model.fit(jr.PRNGKey(2), data=data, n_early_stopping_patience=25, n_iter=n_iter)
inference_result, _ = model.sample_posterior(jr.PRNGKey(3), params, y_observed)

plot_posterior(inference_result)
plt.show()


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--n-iter", type=int, default=10)
args = parser.parse_args()
run(args.n_iter)
14 changes: 9 additions & 5 deletions examples/mixture_model-nle.py
@@ -1,4 +1,7 @@
"""Neural likelihood estimation example.
Demonstrates NLE on a simple mixture model.
"""
import matplotlib.pyplot as plt
from jax import numpy as jnp, random as jr
from tensorflow_probability.substrates.jax import distributions as tfd
@@ -24,14 +27,14 @@ def simulator_fn(seed, theta):
return y


-def run(use_spf):
+def run(use_spf, n_iter):
y_observed = jnp.array([-2.0, 1.0])
fns = prior_fn, simulator_fn
-    neural_network = make_spf(2, -5.0, 5.0, n_params=3) if use_spf else make_mdn(2, 10)
+    neural_network = make_spf(2, -5.0, 5.0, n_params=10) if use_spf else make_mdn(2, 10)
model = NLE(fns, neural_network)

data, _ = model.simulate_data(jr.PRNGKey(1), n_simulations=10_000)
-    params, info = model.fit(jr.PRNGKey(2), data=data, n_early_stopping_patience=25)
+    params, info = model.fit(jr.PRNGKey(2), data=data, n_early_stopping_patience=25, n_iter=n_iter)
inference_result, _ = model.sample_posterior(jr.PRNGKey(3), params, y_observed)

plot_posterior(inference_result)
Expand All @@ -41,6 +44,7 @@ def run(use_spf):
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--use-spf", action="store_true", default=True)
parser.add_argument("--use-spf", action="store_true", default=False)
parser.add_argument("--n-iter", type=int, default=1_000)
args = parser.parse_args()
-    run(False)
+    run(args.use_spf, args.n_iter)
50 changes: 50 additions & 0 deletions examples/mixture_model-npe.py
@@ -0,0 +1,50 @@
"""Neural posterior estimation example.
Demonstrates NPE on a simple mixture model.
"""
import argparse

import matplotlib.pyplot as plt
from jax import numpy as jnp, random as jr
from tensorflow_probability.substrates.jax import distributions as tfd

from sbijax import plot_posterior, NPE
from sbijax.nn import make_maf


def prior_fn():
prior = tfd.JointDistributionNamed(dict(
theta=tfd.Normal(jnp.zeros(2), 1)
), batch_ndims=0)
return prior


def simulator_fn(seed, theta):
mean = theta["theta"].reshape(-1, 2)
n = mean.shape[0]
data_key, cat_key = jr.split(seed)
categories = tfd.Categorical(logits=jnp.zeros(2)).sample(seed=cat_key, sample_shape=(n,))
scales = jnp.array([1.0, 0.1])[categories].reshape(-1, 1)
y = tfd.Normal(mean, scales).sample(seed=data_key)
return y


def run(n_iter):
y_observed = jnp.array([-2.0, 1.0])
fns = prior_fn, simulator_fn
neural_network = make_maf(2)
model = NPE(fns, neural_network)

data, _ = model.simulate_data(jr.PRNGKey(1), n_simulations=10_000)
params, info = model.fit(jr.PRNGKey(2), data=data, n_early_stopping_patience=25, n_iter=n_iter)
inference_result, _ = model.sample_posterior(jr.PRNGKey(3), params, y_observed)

plot_posterior(inference_result)
plt.show()


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--n-iter", type=int, default=1_000)
args = parser.parse_args()
run(args.n_iter)
50 changes: 50 additions & 0 deletions examples/mixture_model-nre.py
@@ -0,0 +1,50 @@
"""Neural ratio estimation example.
Demonstrates NRE on a simple mixture model.
"""
import argparse

import matplotlib.pyplot as plt
from jax import numpy as jnp, random as jr
from tensorflow_probability.substrates.jax import distributions as tfd

from sbijax import plot_posterior, NRE
from sbijax.nn import make_mlp


def prior_fn():
prior = tfd.JointDistributionNamed(dict(
theta=tfd.Normal(jnp.zeros(2), 1)
), batch_ndims=0)
return prior


def simulator_fn(seed, theta):
mean = theta["theta"].reshape(-1, 2)
n = mean.shape[0]
data_key, cat_key = jr.split(seed)
categories = tfd.Categorical(logits=jnp.zeros(2)).sample(seed=cat_key, sample_shape=(n,))
scales = jnp.array([1.0, 0.1])[categories].reshape(-1, 1)
y = tfd.Normal(mean, scales).sample(seed=data_key)
return y


def run(n_iter):
y_observed = jnp.array([-2.0, 1.0])
fns = prior_fn, simulator_fn
neural_network = make_mlp()
model = NRE(fns, neural_network)

data, _ = model.simulate_data(jr.PRNGKey(1), n_simulations=10_000)
params, info = model.fit(jr.PRNGKey(2), data=data, n_early_stopping_patience=25, n_iter=n_iter)
inference_result, _ = model.sample_posterior(jr.PRNGKey(3), params, y_observed)

plot_posterior(inference_result)
plt.show()


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--n-iter", type=int, default=1_000)
args = parser.parse_args()
run(args.n_iter)
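
The four mixture-model examples (CMPE, NLE, NPE, NRE) share the same prior_fn and simulator_fn. A hypothetical smoke test, not part of this commit, that exercises the pair directly and checks the simulator's output shape:

# Hypothetical shape check (assumed helper script, not in the commit):
from jax import numpy as jnp, random as jr
from tensorflow_probability.substrates.jax import distributions as tfd


def prior_fn():
    return tfd.JointDistributionNamed(dict(
        theta=tfd.Normal(jnp.zeros(2), 1)
    ), batch_ndims=0)


def simulator_fn(seed, theta):
    # mixture of two noise scales: each row draws its scale at random
    mean = theta["theta"].reshape(-1, 2)
    n = mean.shape[0]
    data_key, cat_key = jr.split(seed)
    categories = tfd.Categorical(logits=jnp.zeros(2)).sample(seed=cat_key, sample_shape=(n,))
    scales = jnp.array([1.0, 0.1])[categories].reshape(-1, 1)
    return tfd.Normal(mean, scales).sample(seed=data_key)


theta = prior_fn().sample(seed=jr.PRNGKey(0), sample_shape=(5,))
y = simulator_fn(jr.PRNGKey(1), theta)
assert y.shape == (5, 2)  # one 2-dimensional observation per prior draw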