Skip to content

Commit

Permalink
Fix
Browse files Browse the repository at this point in the history
  • Loading branch information
dirmeier committed Aug 18, 2024
1 parent a4ec60b commit 77003fa
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 37 deletions.
2 changes: 1 addition & 1 deletion .gitattributes
Original file line number Diff line number Diff line change
@@ -1 +1 @@
*.ipynb linguist-documentation
*.ipynb linguist-vendored
49 changes: 13 additions & 36 deletions examples/slcp-snle.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,12 @@
likelihood complex posterior model.
"""
import argparse
from functools import partial

import distrax
import haiku as hk
import jax
import matplotlib.pyplot as plt
import optax
import pandas as pd
import seaborn as sns
from jax import numpy as jnp
from jax import random as jr
from jax import scipy as jsp
Expand All @@ -28,7 +25,6 @@
from tensorflow_probability.substrates.jax import distributions as tfd

from sbijax import SNLE, inference_data_as_dictionary
from sbijax.mcmc import sample_with_nuts
from sbijax.nn import make_maf


Expand Down Expand Up @@ -117,7 +113,7 @@ def _flow(method, **kwargs):
_decoder_fn(n_dimension - n_latent),
conditioner=MADE(
n_latent,
[50, n_latent * 2],
[64, 64],
2,
w_init=hk.initializers.TruncatedNormal(0.001),
b_init=jnp.zeros,
Expand All @@ -132,7 +128,7 @@ def _flow(method, **kwargs):
bijector_fn=_bijector_fn,
conditioner=MADE(
n_dimension,
[50, n_dimension * 2],
[64, 64],
2,
w_init=hk.initializers.TruncatedNormal(0.001),
b_init=jnp.zeros,
Expand Down Expand Up @@ -188,37 +184,18 @@ def run(n_rounds, n_iter):
sample_key, rng_key = jr.split(jr.PRNGKey(3))
inference_results, _ = snl.sample_posterior(sample_key, params, y_obs)

sample_key, rng_key = jr.split(rng_key)
log_density_partial = partial(log_density_fn, y=y_obs)
log_density = lambda x: log_density_partial(**x)
slice_samples = sample_with_nuts(
sample_key,
log_density,
prior_fn().sample,
)
slice_samples = slice_samples['theta'].reshape(-1, 5)
snl_samples = inference_data_as_dictionary(inference_results.posterior)["theta"]

g = sns.PairGrid(pd.DataFrame(slice_samples))
g.map_upper(sns.scatterplot, color="black", marker=".", edgecolor=None, s=2)
g.map_diag(plt.hist, color="black")
for ax in g.axes.flatten():
ax.set_xlim(-5, 5)
ax.set_ylim(-5, 5)
g.fig.set_figheight(5)
g.fig.set_figwidth(5)
plt.show()

fig, axes = plt.subplots(5, 2)
samples = inference_data_as_dictionary(inference_results.posterior)["theta"]
_, axes = plt.subplots(figsize=(12, 10), nrows=5, ncols=5)
for i in range(0, 5):
for j in range(0, 5):
ax = axes[i, j]
if i < j:
ax.axis('off')
else:
ax.hexbin(samples[..., j], samples[..., i], gridsize=50,
bins='log')
for i in range(5):
sns.histplot(slice_samples[:, i], color="darkgrey", ax=axes[i, 0])
sns.histplot(snl_samples[:, i], color="darkblue", ax=axes[i, 1])
axes[i, 0].set_title(rf"Sampled posterior $\theta_{i}$")
axes[i, 1].set_title(rf"Approximated posterior $\theta_{i}$")
for j in range(2):
axes[i, j].set_xlim(-5, 5)
sns.despine()
plt.tight_layout()
axes[i, i].hist(samples[..., i], color="black")
plt.show()


Expand Down

0 comments on commit 77003fa

Please sign in to comment.