Add more tests (#49)
* Remove manuscript+supplement
* Deprecate Distrax where possible
* Loosen TF pins
dirmeier authored Dec 3, 2024
1 parent 6176367 commit cc6fd38
Showing 36 changed files with 42 additions and 4,055 deletions.
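
The Distrax deprecation follows a single pattern across the touched files: each distrax distribution is replaced by its counterpart from TensorFlow Probability's JAX substrate. A minimal sketch of the pattern, assuming tfd aliases the substrate's distributions module the way the updated examples use it:

# Migration sketch: Distrax -> TFP-on-JAX. Assumes `tfd` is bound to the
# JAX substrate's distributions module, as in the edited examples.
from jax import random as jr
from tensorflow_probability.substrates import jax as tfp

tfd = tfp.distributions

key = jr.PRNGKey(0)
# Before: distrax.Normal(0.0, 1.0).sample(seed=key, sample_shape=(4, 2))
# After: the same call signature, only the constructor's namespace changes.
x = tfd.Normal(0.0, 1.0).sample(seed=key, sample_shape=(4, 2))
lp = tfd.Normal(0.0, 1.0).log_prob(x)

Both libraries accept seed= and sample_shape= when sampling, which is why the call sites in the diffs below change only in the namespace of the constructor.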
Makefile (3 changes: 0 additions & 3 deletions)

@@ -18,6 +18,3 @@ lints:
 
 format:
 	hatch run test:format
-
-docs:
-	cd docs && make html
examples/slcp-nass_smcabc.py (3 changes: 1 addition & 2 deletions)

@@ -5,7 +5,6 @@
 """
 import argparse
 
-import distrax
 import jax
 from jax import numpy as jnp
 from jax import random as jr
@@ -39,7 +38,7 @@ def _unpack_params(ps):
         return m0, m1, s0, s1, r
 
     m0, m1, s0, s1, r = _unpack_params(theta)
-    us = distrax.Normal(0.0, 1.0).sample(
+    us = tfd.Normal(0.0, 1.0).sample(
         seed=us_key, sample_shape=(theta.shape[0], theta.shape[1], 4, 2)
     )
     xs = jnp.empty_like(us)
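
Only the constructor namespace changes here; the sampling semantics carry over. A quick shape check for the call above, written against TFP's JAX substrate (the key and dimensions are illustrative):

# For a scalar Normal, batch and event shapes are empty, so the draw's shape
# equals sample_shape, matching what the distrax call produced.
from jax import random as jr
from tensorflow_probability.substrates import jax as tfp

tfd = tfp.distributions

us_key = jr.PRNGKey(1)  # illustrative key, standing in for the example's us_key
us = tfd.Normal(0.0, 1.0).sample(seed=us_key, sample_shape=(10, 5, 4, 2))
assert us.shape == (10, 5, 4, 2)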
examples/slcp-snle.py (20 changes: 10 additions & 10 deletions)

@@ -5,11 +5,11 @@
 """
 import argparse
 
-import distrax
 import haiku as hk
 import jax
 import matplotlib.pyplot as plt
 import optax
+import surjectors
 from jax import numpy as jnp
 from jax import random as jr
 from jax import scipy as jsp
@@ -49,7 +49,7 @@ def _unpack_params(ps):
         return m0, m1, s0, s1, r
 
     m0, m1, s0, s1, r = _unpack_params(theta)
-    us = distrax.Normal(0.0, 1.0).sample(
+    us = tfd.Normal(0.0, 1.0).sample(
         seed=us_key, sample_shape=(theta.shape[0], theta.shape[1], 4, 2)
     )
     xs = jnp.empty_like(us)
@@ -67,13 +67,13 @@ def likelihood_fn(theta, y):
     corr = s1 * s2 * jnp.tanh(theta[4])
     cov = jnp.array([[s1**2, corr], [corr, s2**2]])
     cov = jsp.linalg.block_diag(*[cov for _ in range(4)])
-    p = distrax.MultivariateNormalFullCovariance(mu, cov)
+    p = tfd.MultivariateNormalFullCovariance(mu, cov)
     return p.log_prob(y)
 
 
 def log_density_fn(theta, y):
-    prior_lp = distrax.Independent(
-        distrax.Uniform(jnp.full(5, -3.0), jnp.full(5, 3.0)), 1
+    prior_lp = tfd.Independent(
+        tfd.Uniform(jnp.full(5, -3.0), jnp.full(5, 3.0)), 1
     ).log_prob(theta)
     likelihood_lp = likelihood_fn(theta, y)
@@ -84,7 +84,7 @@ def log_density_fn(theta, y):
 def make_model(dim, use_surjectors):
     def _bijector_fn(params):
         means, log_scales = unstack(params, -1)
-        return distrax.ScalarAffine(means, jnp.exp(log_scales))
+        return surjectors.ScalarAffine(means, jnp.exp(log_scales))
 
     def _decoder_fn(n_dim):
         decoder_net = make_mlp(
@@ -95,8 +95,8 @@ def _decoder_fn(n_dim):
         def _fn(z):
             params = decoder_net(z)
             mu, log_scale = jnp.split(params, 2, -1)
-            return distrax.Independent(
-                distrax.Normal(mu, jnp.exp(log_scale)), 1
+            return tfd.Independent(
+                tfd.Normal(mu, jnp.exp(log_scale)), 1
             )
 
         return _fn
@@ -140,8 +140,8 @@ def _flow(method, **kwargs):
             layers.append(Permutation(order, 1))
         chain = Chain(layers)
 
-        base_distribution = distrax.Independent(
-            distrax.Normal(jnp.zeros(n_dimension), jnp.ones(n_dimension)),
+        base_distribution = tfd.Independent(
+            tfd.Normal(jnp.zeros(n_dimension), jnp.ones(n_dimension)),
             reinterpreted_batch_ndims=1,
         )
         td = TransformedDistribution(base_distribution, chain)
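
After the swap, the flow's base measure is still a standard diagonal Gaussian, now expressed as a tfd.Independent over n_dimension scalar Normals so that its event shape matches the flow's dimensionality. A sketch of that construction (n_dimension is illustrative):

# Base distribution of the flow: event_shape [n_dimension], empty batch
# shape, so log_prob over a batch of vectors returns one value per vector.
import jax.numpy as jnp
from tensorflow_probability.substrates import jax as tfp

tfd = tfp.distributions

n_dimension = 8  # illustrative; the example derives this from the data
base_distribution = tfd.Independent(
    tfd.Normal(jnp.zeros(n_dimension), jnp.ones(n_dimension)),
    reinterpreted_batch_ndims=1,
)
z = jnp.zeros((3, n_dimension))
assert base_distribution.log_prob(z).shape == (3,)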