Skip to content

Commit

Permalink
Fix tutorials (#13)
Browse files Browse the repository at this point in the history
* Fix tutorials
* Fix test set
* Minuscule changes/improvements
* Increment version
  • Loading branch information
dirmeier authored Aug 21, 2023
1 parent 76acb19 commit 8f42e19
Show file tree
Hide file tree
Showing 16 changed files with 252 additions and 401 deletions.
74 changes: 0 additions & 74 deletions examples/abc_bivariate_gaussian.py

This file was deleted.

21 changes: 9 additions & 12 deletions examples/bivariate_gaussian_smcabc.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,14 @@
from sbijax import SMCABC


def prior_model_fns(leng):
p = distrax.Independent(
distrax.Uniform(jnp.full(leng, -2.0), jnp.full(leng, 2.0)), 1
)
def prior_model_fns():
p = distrax.Independent(distrax.Normal(jnp.zeros(2), jnp.ones(2)), 1)
return p.sample, p.log_prob


def simulator_fn(seed, theta):
p = distrax.MultivariateNormalDiag(theta, 0.1 * jnp.ones_like(theta))
y = p.sample(seed=seed)
p = distrax.Normal(jnp.zeros_like(theta), 0.1)
y = theta + p.sample(seed=seed)
return y


Expand All @@ -38,18 +36,17 @@ def distance_fn(y_simulated, y_observed):


def run():
len_thetas = 2
y_observed = jnp.ones(len_thetas)
y_observed = jnp.array([-2.0, 1.0])

prior_simulator_fn, prior_logdensity_fn = prior_model_fns(len_thetas)
prior_simulator_fn, prior_logdensity_fn = prior_model_fns()
fns = (prior_simulator_fn, prior_logdensity_fn), simulator_fn

smc = SMCABC(fns, summary_fn, distance_fn)
smc.fit(23, y_observed)
smc_samples, _ = smc.sample_posterior(10, 1000, 1000, 0.75, 500)
smc_samples, _ = smc.sample_posterior(10, 1000, 1000, 0.8, 500)

fig, axes = plt.subplots(len_thetas)
for i in range(len_thetas):
fig, axes = plt.subplots(2)
for i in range(2):
sns.histplot(smc_samples[:, i], color="darkblue", ax=axes[i])
axes[i].set_title(rf"Approximated posterior $\theta_{i}$")
sns.despine()
Expand Down
77 changes: 46 additions & 31 deletions examples/bivariate_gaussian_snl.py
Original file line number Diff line number Diff line change
@@ -1,39 +1,43 @@
"""
Example using SNL and masked coupling flows
Example using SNL and masked autoregressive flows flows
"""

from functools import partial

import distrax
import haiku as hk
import jax
import matplotlib.pyplot as plt
import optax
import seaborn as sns
from jax import numpy as jnp
from jax import random
from surjectors import Chain, MaskedCoupling, TransformedDistribution
from surjectors.conditioners import mlp_conditioner
from surjectors.util import make_alternating_binary_mask
from surjectors import (
Chain,
MaskedAutoregressive,
Permutation,
TransformedDistribution,
)
from surjectors.conditioners import MADE
from surjectors.util import unstack

from sbijax import SNL
from sbijax.mcmc import sample_with_nuts
from sbijax.mcmc import sample_with_slice


def prior_model_fns():
p = distrax.Independent(
distrax.Uniform(jnp.full(2, -3.0), jnp.full(2, 3.0)), 1
)
p = distrax.Independent(distrax.Normal(jnp.zeros(2), jnp.ones(2)), 1)
return p.sample, p.log_prob


def simulator_fn(seed, theta):
p = distrax.MultivariateNormalDiag(theta, 0.1 * jnp.ones_like(theta))
y = p.sample(seed=seed)
p = distrax.Normal(jnp.zeros_like(theta), 0.1)
y = theta + p.sample(seed=seed)
return y


def log_density_fn(theta, y):
prior = distrax.Uniform(jnp.full(2, -3.0), jnp.full(2, 3.0))
prior = distrax.Independent(distrax.Normal(jnp.zeros(2), jnp.ones(2)), 1)
likelihood = distrax.MultivariateNormalDiag(
theta, 0.1 * jnp.ones_like(theta)
)
Expand All @@ -44,19 +48,27 @@ def log_density_fn(theta, y):

def make_model(dim):
def _bijector_fn(params):
means, log_scales = jnp.split(params, 2, -1)
means, log_scales = unstack(params, -1)
return distrax.ScalarAffine(means, jnp.exp(log_scales))

def _flow(method, **kwargs):
layers = []
for i in range(2):
mask = make_alternating_binary_mask(dim, i % 2 == 0)
layer = MaskedCoupling(
mask=mask,
bijector=_bijector_fn,
conditioner=mlp_conditioner([8, 8, dim * 2]),
order = jnp.arange(dim)
for i in range(5):
layer = MaskedAutoregressive(
bijector_fn=_bijector_fn,
conditioner=MADE(
dim,
[50, dim * 2],
2,
w_init=hk.initializers.TruncatedNormal(0.001),
b_init=jnp.zeros,
activation=jax.nn.tanh,
),
)
order = order[::-1]
layers.append(layer)
layers.append(Permutation(order, 1))
chain = Chain(layers)

base_distribution = distrax.Independent(
Expand All @@ -72,11 +84,10 @@ def _flow(method, **kwargs):


def run():
rng_seq = hk.PRNGSequence(0)
y_observed = jnp.array([-1.0, 1.0])
y_observed = jnp.array([-2.0, 1.0])

log_density_partial = partial(log_density_fn, y=y_observed)
log_density = lambda x: log_density_partial(**x)
log_density = lambda x: jax.vmap(log_density_partial)(x)

prior_simulator_fn, prior_logdensity_fn = prior_model_fns()
fns = (prior_simulator_fn, prior_logdensity_fn), simulator_fn
Expand All @@ -86,29 +97,33 @@ def run():
params, info = snl.fit(
random.PRNGKey(23),
y_observed,
n_rounds=1,
optimizer=optimizer,
n_rounds=3,
max_n_iter=100,
batch_size=64,
n_early_stopping_patience=5,
sampler="slice",
)

nuts_samples = sample_with_nuts(rng_seq, log_density, 2, 4, 2000, 1000)
slice_samples = sample_with_slice(
hk.PRNGSequence(0), log_density, 4, 2000, 1000, prior_simulator_fn
)
slice_samples = slice_samples.reshape(-1, 2)
snl_samples, _ = snl.sample_posterior(
params, 4, 10000, 7500, sampler="slice"
params, 4, 2000, 1000, sampler="slice"
)

snl_samples = snl_samples.reshape(-1, 2)
nuts_samples = nuts_samples.reshape(-1, 2)

print(f"Took n={snl.n_total_simulations} simulations in total")
fig, axes = plt.subplots(2, 2)
for i in range(2):
sns.histplot(nuts_samples[:, i], color="darkgrey", ax=axes.flatten()[i])
sns.histplot(
slice_samples[:, i], color="darkgrey", ax=axes.flatten()[i]
)
sns.histplot(
snl_samples[:, i], color="darkblue", ax=axes.flatten()[i + 2]
)
axes.flatten()[i].set_title(rf"Sampled posterior $\theta_{i}$")
axes.flatten()[i + 2].set_title(
rf"Approximated posterior $\theta_{i + 2}$"
)
axes.flatten()[i + 2].set_title(rf"Approximated posterior $\theta_{i}$")
sns.despine()
plt.tight_layout()
plt.show()
Expand Down
19 changes: 9 additions & 10 deletions examples/bivariate_gaussian_snp.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
Example using SNL and masked coupling flows
Example using SNP and masked autoregressive flows
"""

import distrax
Expand All @@ -20,9 +20,7 @@


def prior_model_fns():
p = distrax.Independent(
distrax.Uniform(-2 * jnp.ones(2), 2 * jnp.ones(2)), 1
)
p = distrax.Independent(distrax.Normal(jnp.zeros(2), jnp.ones(2)), 1)
return p.sample, p.log_prob


Expand All @@ -45,7 +43,7 @@ def _flow(method, **kwargs):
bijector_fn=_bijector_fn,
conditioner=MADE(
dim,
[50, 50, dim * 2],
[50, dim * 2],
2,
w_init=hk.initializers.TruncatedNormal(0.001),
b_init=jnp.zeros,
Expand Down Expand Up @@ -74,24 +72,25 @@ def run():
prior_simulator_fn, prior_logdensity_fn = prior_model_fns()
fns = (prior_simulator_fn, prior_logdensity_fn), simulator_fn

optimizer = optax.chain(optax.clip(5.0), optax.adamw(1e-04))
optimizer = optax.adamw(1e-04)
snp = SNP(fns, make_model(2))
params, info = snp.fit(
random.PRNGKey(2),
y_observed,
n_rounds=5,
n_rounds=3,
optimizer=optimizer,
n_early_stopping_patience=10,
batch_size=128,
batch_size=64,
n_atoms=10,
max_iter=200,
max_n_iter=100,
)

print(f"Took n={snp.n_total_simulations} simulations in total")
snp_samples, _ = snp.sample_posterior(params, 10000)
fig, axes = plt.subplots(2)
for i, ax in enumerate(axes):
sns.histplot(snp_samples[:, i], color="darkblue", ax=ax)
ax.set_xlim([-2.0, 2.0])
ax.set_xlim([-3.0, 3.0])
sns.despine()
plt.tight_layout()
plt.show()
Expand Down
18 changes: 9 additions & 9 deletions examples/slcp_smcabc.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,10 @@
from jax import numpy as jnp
from jax import random
from jax import scipy as jsp
from jax import vmap

from sbijax import SMCABC
from sbijax.mcmc import sample_with_nuts
from sbijax.mcmc import sample_with_slice


def prior_model_fns():
Expand Down Expand Up @@ -102,7 +103,7 @@ def run():

smc = SMCABC(fns, summary_fn, distance_fn)
smc.fit(23, y_observed)
smc_samples = smc.sample_posterior(5, 1000, 10, 0.9, 500)
smc_samples, _ = smc.sample_posterior(5, 1000, 10, 0.9, 500)

def log_density_fn(theta, y):
prior_lp = prior_logdensity_fn(theta)
Expand All @@ -112,15 +113,14 @@ def log_density_fn(theta, y):
return lp

log_density_partial = partial(log_density_fn, y=y_observed)
log_density = lambda x: log_density_partial(**x)
log_density = lambda x: vmap(log_density_partial)(x)

rng_seq = hk.PRNGSequence(12)
nuts_samples = sample_with_nuts(
rng_seq, log_density, len_theta, 4, 20000, 5000
slice_samples = sample_with_slice(
hk.PRNGSequence(12), log_density, 4, 10000, 5000, prior_simulator_fn
)
nuts_samples = nuts_samples.reshape(-1, len_theta)
slice_samples = slice_samples.reshape(-1, len_theta)

g = sns.PairGrid(pd.DataFrame(nuts_samples))
g = sns.PairGrid(pd.DataFrame(slice_samples))
g.map_upper(sns.scatterplot, color="black", marker=".", edgecolor=None, s=2)
g.map_diag(plt.hist, color="black")
for ax in g.axes.flatten():
Expand All @@ -132,7 +132,7 @@ def log_density_fn(theta, y):

fig, axes = plt.subplots(len_theta, 2)
for i in range(len_theta):
sns.histplot(nuts_samples[:, i], color="darkgrey", ax=axes[i, 0])
sns.histplot(slice_samples[:, i], color="darkgrey", ax=axes[i, 0])
sns.histplot(smc_samples[:, i], color="darkblue", ax=axes[i, 1])
axes[i, 0].set_title(rf"Sampled posterior $\theta_{i}$")
axes[i, 1].set_title(rf"Approximated posterior $\theta_{i}$")
Expand Down
Loading

0 comments on commit 8f42e19

Please sign in to comment.