Deprecate distrax and update examples (#48)
Showing 19 changed files with 360 additions and 188 deletions.
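The main change is dropping distrax in favor of TensorFlow Probability's JAX substrate for defining distributions in the example scripts. The pre-change code is not part of this diff, so the "before" half of the sketch below is an assumption about how a distrax prior would typically have been written; the "after" half mirrors the prior_fn used in the updated examples.

# Sketch of the migration. The "before" block is hypothetical and not taken
# from this diff; the "after" block matches the new example scripts.

# Before (distrax, deprecated by this commit) -- assumed:
# import distrax
# def prior_fn():
#     return distrax.Independent(distrax.Normal(jnp.zeros(2), 1.0), 1)

# After (TensorFlow Probability on JAX), as in the updated examples:
from jax import numpy as jnp
from tensorflow_probability.substrates.jax import distributions as tfd

def prior_fn():
    return tfd.JointDistributionNamed(
        dict(theta=tfd.Normal(jnp.zeros(2), 1)),  # named parameter block "theta"
        batch_ndims=0,
    )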
@@ -0,0 +1 @@
*.ipynb linguist-vendored
@@ -0,0 +1,47 @@
name: examples

on:
  push:
    branches: [ main ]
  pull_request:
    branches: [ main ]

jobs:
  precommit:
    name: Pre-commit checks
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v3

  examples:
    runs-on: ubuntu-latest
    needs:
      - precommit
    strategy:
      matrix:
        python-version: [3.11]
    steps:
      - uses: actions/checkout@v3
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v3
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install dependencies
        run: |
          pip install hatch matplotlib
      - name: Build package
        run: |
          pip install jaxlib jax
          pip install .
      - name: Run tests
        run: |
          python examples/bivariate_gaussian-smcabc.py --n-rounds 1
          python examples/mixture_model-cmpe.py --n-iter 10
          python examples/mixture_model-nle.py --n-iter 10
          python examples/mixture_model-nle.py --n-iter 10 --use-spf
          python examples/mixture_model-npe.py --n-iter 10
          python examples/mixture_model-nre.py --n-iter 10
          python examples/slcp-fmpe.py --n-iter 10
          python examples/slcp-nass_nle.py --n-iter 10 --n-rounds 1
          python examples/slcp-nass_smcabc.py --n-iter 10 --n-rounds 1
          python examples/slcp-snle.py --n-iter 10 --n-rounds 1
This file was deleted.
@@ -0,0 +1,50 @@
"""Consistency model posterior estimation example.

Demonstrates CMPE on a simple mixture model.
"""

import argparse

import matplotlib.pyplot as plt
from jax import numpy as jnp, random as jr
from tensorflow_probability.substrates.jax import distributions as tfd

from sbijax import plot_posterior, CMPE
from sbijax.nn import make_cm


def prior_fn():
    prior = tfd.JointDistributionNamed(dict(
        theta=tfd.Normal(jnp.zeros(2), 1)
    ), batch_ndims=0)
    return prior


def simulator_fn(seed, theta):
    mean = theta["theta"].reshape(-1, 2)
    n = mean.shape[0]
    data_key, cat_key = jr.split(seed)
    categories = tfd.Categorical(logits=jnp.zeros(2)).sample(seed=cat_key, sample_shape=(n,))
    scales = jnp.array([1.0, 0.1])[categories].reshape(-1, 1)
    y = tfd.Normal(mean, scales).sample(seed=data_key)
    return y


def run(n_iter):
    y_observed = jnp.array([-2.0, 1.0])
    fns = prior_fn, simulator_fn
    neural_network = make_cm(2, 64)
    model = CMPE(fns, neural_network)

    data, _ = model.simulate_data(jr.PRNGKey(1), n_simulations=10_000)
    params, info = model.fit(jr.PRNGKey(2), data=data, n_early_stopping_patience=25, n_iter=n_iter)
    inference_result, _ = model.sample_posterior(jr.PRNGKey(3), params, y_observed)

    plot_posterior(inference_result)
    plt.show()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--n-iter", type=int, default=10)
    args = parser.parse_args()
    run(args.n_iter)
@@ -0,0 +1,50 @@
"""Neural posterior estimation example.

Demonstrates NPE on a simple mixture model.
"""

import argparse

import matplotlib.pyplot as plt
from jax import numpy as jnp, random as jr
from tensorflow_probability.substrates.jax import distributions as tfd

from sbijax import plot_posterior, NPE
from sbijax.nn import make_maf


def prior_fn():
    prior = tfd.JointDistributionNamed(dict(
        theta=tfd.Normal(jnp.zeros(2), 1)
    ), batch_ndims=0)
    return prior


def simulator_fn(seed, theta):
    mean = theta["theta"].reshape(-1, 2)
    n = mean.shape[0]
    data_key, cat_key = jr.split(seed)
    categories = tfd.Categorical(logits=jnp.zeros(2)).sample(seed=cat_key, sample_shape=(n,))
    scales = jnp.array([1.0, 0.1])[categories].reshape(-1, 1)
    y = tfd.Normal(mean, scales).sample(seed=data_key)
    return y


def run(n_iter):
    y_observed = jnp.array([-2.0, 1.0])
    fns = prior_fn, simulator_fn
    neural_network = make_maf(2)
    model = NPE(fns, neural_network)

    data, _ = model.simulate_data(jr.PRNGKey(1), n_simulations=10_000)
    params, info = model.fit(jr.PRNGKey(2), data=data, n_early_stopping_patience=25, n_iter=n_iter)
    inference_result, _ = model.sample_posterior(jr.PRNGKey(3), params, y_observed)

    plot_posterior(inference_result)
    plt.show()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--n-iter", type=int, default=1_000)
    args = parser.parse_args()
    run(args.n_iter)
@@ -0,0 +1,50 @@
"""Neural ratio estimation example.

Demonstrates NRE on a simple mixture model.
"""

import argparse

import matplotlib.pyplot as plt
from jax import numpy as jnp, random as jr
from tensorflow_probability.substrates.jax import distributions as tfd

from sbijax import plot_posterior, NRE
from sbijax.nn import make_mlp


def prior_fn():
    prior = tfd.JointDistributionNamed(dict(
        theta=tfd.Normal(jnp.zeros(2), 1)
    ), batch_ndims=0)
    return prior


def simulator_fn(seed, theta):
    mean = theta["theta"].reshape(-1, 2)
    n = mean.shape[0]
    data_key, cat_key = jr.split(seed)
    categories = tfd.Categorical(logits=jnp.zeros(2)).sample(seed=cat_key, sample_shape=(n,))
    scales = jnp.array([1.0, 0.1])[categories].reshape(-1, 1)
    y = tfd.Normal(mean, scales).sample(seed=data_key)
    return y


def run(n_iter):
    y_observed = jnp.array([-2.0, 1.0])
    fns = prior_fn, simulator_fn
    neural_network = make_mlp()
    model = NRE(fns, neural_network)

    data, _ = model.simulate_data(jr.PRNGKey(1), n_simulations=10_000)
    params, info = model.fit(jr.PRNGKey(2), data=data, n_early_stopping_patience=25, n_iter=n_iter)
    inference_result, _ = model.sample_posterior(jr.PRNGKey(3), params, y_observed)

    plot_posterior(inference_result)
    plt.show()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--n-iter", type=int, default=1_000)
    args = parser.parse_args()
    run(args.n_iter)
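The three mixture-model examples above share the same prior_fn and simulator_fn and differ only in which estimator class and network factory they construct. The helper below is a hypothetical convenience wrapper, not part of this commit, that collects those swap points using only constructors that appear in the diff (CMPE/make_cm, NPE/make_maf, NRE/make_mlp); prior_fn and simulator_fn are the functions defined in the example scripts.

# Hypothetical helper illustrating the only differences between the examples:
# the estimator class and the neural network passed to it.
from sbijax import CMPE, NPE, NRE
from sbijax.nn import make_cm, make_maf, make_mlp


def make_estimator(kind, prior_fn, simulator_fn):
    fns = prior_fn, simulator_fn
    if kind == "cmpe":
        return CMPE(fns, make_cm(2, 64))   # consistency model
    if kind == "npe":
        return NPE(fns, make_maf(2))       # masked autoregressive flow
    if kind == "nre":
        return NRE(fns, make_mlp())        # MLP classifier for ratio estimation
    raise ValueError(f"unknown estimator: {kind}")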