
Commit

Major refactor of codebase
dirmeier committed Jul 17, 2024
1 parent 3f7edcb commit 996674f
Showing 78 changed files with 5,579 additions and 1,780 deletions.
Binary file added .DS_Store
11 changes: 0 additions & 11 deletions .pre-commit-config.yaml
@@ -13,17 +13,6 @@ repos:
      - id: requirements-txt-fixer
      - id: trailing-whitespace

-  - repo: https://github.com/pycqa/bandit
-    rev: 1.7.1
-    hooks:
-      - id: bandit
-        language: python
-        language_version: python3
-        types: [python]
-        args: ["-c", "pyproject.toml"]
-        additional_dependencies: ["toml"]
-        files: "(sbijax|examples)"

  - repo: https://github.com/pre-commit/mirrors-mypy
    rev: v0.910-1
    hooks:
6 changes: 3 additions & 3 deletions Makefile
@@ -6,8 +6,8 @@
PKG_VERSION=`hatch version`

tag:
-	git tag -a v${PKG_VERSION} -m v${PKG_VERSION}
-	git push --tag
+	git tag -a v${PKG_VERSION} -m v${PKG_VERSION}
+	git push --tag

tests:
	hatch run test:test
@@ -16,4 +16,4 @@ lints:
	hatch run test:lint

docs:
-	cd docs && make html
+	cd docs && make html
66 changes: 59 additions & 7 deletions README.md
@@ -8,9 +8,7 @@
## About

-`sbijax` implements several algorithms for simulation-based inference in
-[JAX](https://github.com/google/jax) using [Haiku](https://github.com/deepmind/dm-haiku),
-[Distrax](https://github.com/deepmind/distrax) and [BlackJAX](https://github.com/blackjax-devs/blackjax). Specifically, `sbijax` implements
+`sbijax` implements several algorithms for simulation-based inference in [JAX](https://github.com/google/jax), such as

- [Sequential Monte Carlo ABC](https://www.routledge.com/Handbook-of-Approximate-Bayesian-Computation/Sisson-Fan-Beaumont/p/book/9780367733728) (`SMCABC`)
- [Neural Likelihood Estimation](https://arxiv.org/abs/1805.07226) (`SNL`)
@@ -22,14 +20,44 @@
- [Flow matching posterior estimation](https://arxiv.org/abs/2305.17161) (`SFMPE`)
- [Consistency model posterior estimation](https://arxiv.org/abs/2312.05440) (`SCMPE`)

-where the acronyms in parentheses denote the names of the methods in `sbijax`.
+where the acronyms in parentheses denote the names of the classes in `sbijax`. It builds on the Python packages [Surjectors](https://github.com/dirmeier/surjectors), [Haiku](https://github.com/deepmind/dm-haiku),
+[Distrax](https://github.com/deepmind/distrax) and [BlackJAX](https://github.com/blackjax-devs/blackjax).

> [!CAUTION]
> ⚠️ As per the LICENSE file, there is no warranty whatsoever for this free software tool. If you discover bugs, please report them.

## Examples

-You can find several self-contained examples on how to use the algorithms in [examples](https://github.com/dirmeier/sbijax/tree/main/examples).
+`sbijax` uses an object-oriented API with functional elements stemming from JAX. You can, for instance, define
+a neural likelihood estimation method and generate posterior samples like this:

```python
import distrax
import optax
from jax import numpy as jnp, random as jr
from sbijax import SNL
from sbijax.nn import make_affine_maf

def prior_model_fns():
    p = distrax.Independent(distrax.Normal(jnp.zeros(2), jnp.ones(2)), 1)
    return p.sample, p.log_prob

def simulator_fn(seed, theta):
    p = distrax.Normal(jnp.zeros_like(theta), 1.0)
    y = theta + p.sample(seed=seed)
    return y

prior_simulator_fn, prior_logdensity_fn = prior_model_fns()
fns = (prior_simulator_fn, prior_logdensity_fn), simulator_fn
model = SNL(fns, make_affine_maf(2))

y_observed = jnp.array([-1.0, 1.0])
data, _ = model.simulate_data(jr.PRNGKey(0), n_simulations=5)
params, _ = model.fit(jr.PRNGKey(1), data=data, optimizer=optax.adam(0.001))
posterior, _ = model.sample_posterior(jr.PRNGKey(2), params, y_observed)
```
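For intuition, the prior and simulator above are plain callables with a simple shape contract: the prior emits batches of parameter vectors, and the simulator maps each parameter vector (plus a random seed) to a synthetic observation of the same shape. A dependency-free sketch of that contract, with plain NumPy standing in for JAX and Distrax (the helper names here are illustrative, not part of `sbijax`):

```python
import numpy as np

def prior_sample(rng, n):
    # n draws from the standard-normal prior over a 2-dimensional theta,
    # mirroring the distrax prior in the example above
    return rng.normal(size=(n, 2))

def simulator_fn(rng, theta):
    # additive standard-normal observation noise: y = theta + eps
    return theta + rng.normal(size=theta.shape)

rng = np.random.default_rng(0)
theta = prior_sample(rng, 5)
y = simulator_fn(rng, theta)
print(theta.shape, y.shape)  # (5, 2) (5, 2)
```

Any pair of callables following this pattern can fill the roles of the tuple handed to `SNL` in the example above.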

More self-contained examples can be found in [examples](https://github.com/dirmeier/sbijax/tree/main/examples).

## Documentation

@@ -52,11 +80,35 @@ To install the latest GitHub <RELEASE>, use:
pip install git+https://github.com/dirmeier/sbijax@<RELEASE>
```

## Contributing

Contributions in the form of pull requests are more than welcome. A good way to start is to check out issues labelled
[good first issue](https://github.com/dirmeier/sbijax/issues?q=is%3Aissue+is%3Aopen+label%3A%22good+first+issue%22).

In order to contribute:

1) Clone `sbijax` and install `hatch` via `pip install hatch`,
2) create a new branch locally `git checkout -b feature/my-new-feature` or `git checkout -b issue/fixes-bug`,
3) implement your contribution and ideally a test case,
4) test it by calling `hatch run test` on the (Unix) command line,
5) submit a PR 🙂
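The steps above as a terminal transcript (the clone URL is the repository's own; the branch name is illustrative):

```shell
pip install hatch
git clone https://github.com/dirmeier/sbijax && cd sbijax
git checkout -b feature/my-new-feature
# ...implement the change and ideally a test case...
hatch run test
```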

## Citing sbijax

If you find our work relevant to your research, please consider citing:

```bibtex
@article{dirmeier2024sbijax,
author = {Simon Dirmeier and Antonietta Mira and Carlo Albert},
title = {Simulation-based inference with the Python Package sbijax},
year = {2024},
}
```

## Acknowledgements

> [!NOTE]
-> 📝 The API of the package is heavily inspired by the excellent Pytorch-based [`sbi`](https://github.com/sbi-dev/sbi) package which is substantially more
-feature-complete and user-friendly, and better documented.
+> 📝 The API of the package is heavily inspired by the excellent Pytorch-based [`sbi`](https://github.com/sbi-dev/sbi) package.

## Author

