From 77003fa962badc47244595d32ab9784357e7f0f4 Mon Sep 17 00:00:00 2001 From: Simon Dirmeier Date: Sun, 18 Aug 2024 15:31:43 +0200 Subject: [PATCH] Fix --- .gitattributes | 2 +- examples/slcp-snle.py | 49 ++++++++++++------------------------------- 2 files changed, 14 insertions(+), 37 deletions(-) diff --git a/.gitattributes b/.gitattributes index 2f77e91..5be91f9 100644 --- a/.gitattributes +++ b/.gitattributes @@ -1 +1 @@ -*.ipynb linguist-documentation +*.ipynb linguist-vendored diff --git a/examples/slcp-snle.py b/examples/slcp-snle.py index 980ea42..df751be 100644 --- a/examples/slcp-snle.py +++ b/examples/slcp-snle.py @@ -4,15 +4,12 @@ likelihood complex posterior model. """ import argparse -from functools import partial import distrax import haiku as hk import jax import matplotlib.pyplot as plt import optax -import pandas as pd -import seaborn as sns from jax import numpy as jnp from jax import random as jr from jax import scipy as jsp @@ -28,7 +25,6 @@ from tensorflow_probability.substrates.jax import distributions as tfd from sbijax import SNLE, inference_data_as_dictionary -from sbijax.mcmc import sample_with_nuts from sbijax.nn import make_maf @@ -117,7 +113,7 @@ def _flow(method, **kwargs): _decoder_fn(n_dimension - n_latent), conditioner=MADE( n_latent, - [50, n_latent * 2], + [64, 64], 2, w_init=hk.initializers.TruncatedNormal(0.001), b_init=jnp.zeros, @@ -132,7 +128,7 @@ def _flow(method, **kwargs): bijector_fn=_bijector_fn, conditioner=MADE( n_dimension, - [50, n_dimension * 2], + [64, 64], 2, w_init=hk.initializers.TruncatedNormal(0.001), b_init=jnp.zeros, @@ -188,37 +184,18 @@ def run(n_rounds, n_iter): sample_key, rng_key = jr.split(jr.PRNGKey(3)) inference_results, _ = snl.sample_posterior(sample_key, params, y_obs) - sample_key, rng_key = jr.split(rng_key) - log_density_partial = partial(log_density_fn, y=y_obs) - log_density = lambda x: log_density_partial(**x) - slice_samples = sample_with_nuts( - sample_key, - log_density, - prior_fn().sample, - ) - slice_samples = slice_samples['theta'].reshape(-1, 5) - snl_samples = inference_data_as_dictionary(inference_results.posterior)["theta"] - - g = sns.PairGrid(pd.DataFrame(slice_samples)) - g.map_upper(sns.scatterplot, color="black", marker=".", edgecolor=None, s=2) - g.map_diag(plt.hist, color="black") - for ax in g.axes.flatten(): - ax.set_xlim(-5, 5) - ax.set_ylim(-5, 5) - g.fig.set_figheight(5) - g.fig.set_figwidth(5) - plt.show() - - fig, axes = plt.subplots(5, 2) + samples = inference_data_as_dictionary(inference_results.posterior)["theta"] + _, axes = plt.subplots(figsize=(12, 10), nrows=5, ncols=5) + for i in range(0, 5): + for j in range(0, 5): + ax = axes[i, j] + if i < j: + ax.axis('off') + else: + ax.hexbin(samples[..., j], samples[..., i], gridsize=50, + bins='log') for i in range(5): - sns.histplot(slice_samples[:, i], color="darkgrey", ax=axes[i, 0]) - sns.histplot(snl_samples[:, i], color="darkblue", ax=axes[i, 1]) - axes[i, 0].set_title(rf"Sampled posterior $\theta_{i}$") - axes[i, 1].set_title(rf"Approximated posterior $\theta_{i}$") - for j in range(2): - axes[i, j].set_xlim(-5, 5) - sns.despine() - plt.tight_layout() + axes[i, i].hist(samples[..., i], color="black") plt.show()