diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 8ea8a9f..c9ae2a9 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -40,6 +40,8 @@ repos:
     language: python
     language_version: python3
     types: [python]
+    args: ["-c", "pyproject.toml"]
+    additional_dependencies: ["toml"]
     files: "(sbijax|examples)"

 - repo: https://github.com/PyCQA/flake8
@@ -58,6 +60,17 @@ repos:
     args: ["--ignore-missing-imports"]
     files: "(sbijax|examples)"

+- repo: https://github.com/nbQA-dev/nbQA
+  rev: 1.6.3
+  hooks:
+  - id: nbqa-black
+  - id: nbqa-pyupgrade
+    args: [--py39-plus]
+  - id: nbqa-isort
+    args: ['--profile=black']
+  - id: nbqa-flake8
+    args: ['--ignore=E501,E203,E302,E402,E731,W503']
+
 - repo: https://github.com/jorisroovers/gitlint
   rev: v0.18.0
   hooks:
diff --git a/README.md b/README.md
index 55c3e38..1acce8d 100644
--- a/README.md
+++ b/README.md
@@ -2,13 +2,14 @@

 [![status](http://www.repostatus.org/badges/latest/concept.svg)](http://www.repostatus.org/#concept)
 [![ci](https://github.com/dirmeier/sbijax/actions/workflows/ci.yaml/badge.svg)](https://github.com/dirmeier/sbijax/actions/workflows/ci.yaml)
+[![version](https://img.shields.io/pypi/v/sbijax.svg?colorB=black&style=flat)](https://pypi.org/project/sbijax/)

 > Simulation-based inference in JAX

 ## About

 SbiJAX implements several algorithms for simulation-based inference using
-[BlackJAX](https://github.com/blackjax-devs/blackjax), [Haiku](https://github.com/deepmind/dm-haiku) and [JAX](https://github.com/google/jax).
+[JAX](https://github.com/google/jax), [Haiku](https://github.com/deepmind/dm-haiku) and [BlackJAX](https://github.com/blackjax-devs/blackjax).

 SbiJAX so far implements

@@ -37,29 +38,6 @@ To install the latest GitHub <RELEASE>, use:
 pip install git+https://github.com/dirmeier/sbijax@<RELEASE>
 ```

-## Contributing
-
-Contributions in the form of pull requests are more than welcome. A good way to start is to check out issues labelled
-["good first issue"](https://github.com/dirmeier/sbijax/issues?q=is%3Aissue+is%3Aopen+label%3A%22good+first+issue%22).
-In order to contribute:
-
-1) Fork the repository and install `hatch` and `pre-commit`
-
-```bash
-pip install hatch pre-commit
-pre-commit install
-```
-
-2) Create a new branch in your fork and implement your contribution
-
-3) Test your contribution/implementation by calling `hatch run test` on the (Unix) command line before submitting a PR
-
-```bash
-hatch run test:lint
-hatch run test:test
-```
-
-4) Submit a pull request :slightly_smiling_face:
-
 ## Author

 Simon Dirmeier sfyrbnd @ pm me
diff --git a/pyproject.toml b/pyproject.toml
index feee0fa..f157307 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -25,12 +25,16 @@ dependencies = [
     "dm-haiku>=0.0.9",
     "flax>=0.6.3",
     "optax>=0.1.3",
+    "surjectors@git+https://git@github.com/dirmeier/surjectors@v0.2.2",
 ]
 dynamic = ["version"]

 [project.urls]
 homepage = "https://github.com/dirmeier/sbijax"

+[tool.hatch.metadata]
+allow-direct-references = true
+
 [tool.hatch.version]
 path = "sbijax/__init__.py"

@@ -50,7 +54,7 @@ dependencies = [

 [tool.hatch.envs.test.scripts]
 lint = 'pylint sbijax'
-test = 'pytest -v --doctest-modules --cov=./sbi --cov-report=xml sbijax'
+test = 'pytest -v --doctest-modules --cov=./sbijax --cov-report=xml sbijax'


 [tool.black]
diff --git a/sbijax/__init__.py b/sbijax/__init__.py
index 8e3bb0f..004def8 100644
--- a/sbijax/__init__.py
+++ b/sbijax/__init__.py
@@ -2,7 +2,7 @@
 sbijax: Simulation-based inference in JAX
 """

-__version__ = "0.0.10"
+__version__ = "0.0.11"


 from sbijax.abc.rejection_abc import RejectionABC
diff --git a/sbijax/snl.py b/sbijax/snl.py
index e893b80..691e2c5 100644
--- a/sbijax/snl.py
+++ b/sbijax/snl.py
@@ -5,6 +5,8 @@
 import numpy as np
 import optax
 from absl import logging
+
+# TODO(simon): this is a bit of an annoying dependency to have
 from flax.training.early_stopping import EarlyStopping
 from jax import numpy as jnp

diff --git a/sbijax/snl_test.py b/sbijax/snl_test.py
index 9cb15ed..b42b7b2 100644
--- a/sbijax/snl_test.py
+++ b/sbijax/snl_test.py
@@ -1,6 +1,84 @@
 # pylint: skip-file
-import chex
+
+import distrax
+import haiku as hk
+import optax
+from jax import numpy as jnp
+from surjectors import Chain, MaskedCoupling, TransformedDistribution
+from surjectors.conditioners import mlp_conditioner
+from surjectors.util import make_alternating_binary_mask
+
+from sbijax import SNL
+
+
+def prior_model_fns():
+    p = distrax.Independent(
+        distrax.Uniform(jnp.full(2, -3.0), jnp.full(2, 3.0)), 1
+    )
+    return p.sample, p.log_prob
+
+
+def simulator_fn(seed, theta):
+    p = distrax.MultivariateNormalDiag(theta, 0.1 * jnp.ones_like(theta))
+    y = p.sample(seed=seed)
+    return y
+
+
+def log_density_fn(theta, y):
+    prior = distrax.Uniform(jnp.full(2, -3.0), jnp.full(2, 3.0))
+    likelihood = distrax.MultivariateNormalDiag(
+        theta, 0.1 * jnp.ones_like(theta)
+    )
+
+    lp = jnp.sum(prior.log_prob(theta)) + jnp.sum(likelihood.log_prob(y))
+    return lp
+
+
+def make_model(dim):
+    def _bijector_fn(params):
+        means, log_scales = jnp.split(params, 2, -1)
+        return distrax.ScalarAffine(means, jnp.exp(log_scales))
+
+    def _flow(method, **kwargs):
+        layers = []
+        for i in range(2):
+            mask = make_alternating_binary_mask(dim, i % 2 == 0)
+            layer = MaskedCoupling(
+                mask=mask,
+                bijector=_bijector_fn,
+                conditioner=mlp_conditioner([8, 8, dim * 2]),
+            )
+            layers.append(layer)
+        chain = Chain(layers)
+        base_distribution = distrax.Independent(
+            distrax.Normal(jnp.zeros(dim), jnp.ones(dim)),
+            1,
+        )
+        td = TransformedDistribution(base_distribution, chain)
+        return td(method, **kwargs)
+
+    td = hk.transform(_flow)
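+    # without_apply_rng removes the rng argument from .apply, as the flow
+    # does not draw Haiku rng keys internally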
+    td = hk.without_apply_rng(td)
+    return td


 def test_snl():
-    chex.assert_equal(1, 1)
+    rng_seq = hk.PRNGSequence(0)
+    y_observed = jnp.array([-1.0, 1.0])
+
+    prior_simulator_fn, prior_logdensity_fn = prior_model_fns()
+    fns = (prior_simulator_fn, prior_logdensity_fn), simulator_fn
+
+    snl = SNL(fns, make_model(2))
+    params, info = snl.fit(
+        next(rng_seq),
+        y_observed,
+        n_rounds=1,
+        optimizer=optax.adam(1e-4),
+        sampler="slice",
+    )
+    _ = snl.sample_posterior(params, 2, 100, 50, sampler="slice")
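+    # nb: this is a smoke test; it only verifies that fitting and
+    # posterior sampling run end-to-end, not the quality of the samples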