A JAX-based Python package for maximum entropy modeling of multivariate binary data.
JaxEnt is a lightweight Python package for fitting, sampling from, and computing various quantities of maximum entropy distributions with arbitrary constraints. As the name suggests, JaxEnt uses JAX to JIT-compile its functions for CPU/GPU/TPU. JaxEnt implements several popular maximum entropy models (see below), and extending it to other use cases is straightforward.
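To give a feel for what JIT compilation with JAX looks like, here is a minimal sketch that compiles an unnormalized log-probability for a pairwise model over binary states. It is an illustration only, not JaxEnt's actual implementation: the function names, the sign convention, and the ordering of the pairwise terms (upper triangle, row by row) are all assumptions.

```python
import numpy as onp
import jax
import jax.numpy as np  # this README's alias for jax.numpy


def ising_features(x):
    """Map binary states of shape (n, N) to N single-unit terms
    followed by N*(N-1)/2 pairwise products (hypothetical ordering)."""
    n_units = x.shape[-1]
    # Static index arrays, built with plain NumPy so they are constants under jit.
    i, j = onp.triu_indices(n_units, k=1)
    return np.concatenate([x, x[..., i] * x[..., j]], axis=-1)


@jax.jit
def unnormalized_log_prob(factors, x):
    """Inner product of the parameters with the features (assumed sign: +)."""
    return ising_features(x) @ factors


# Three states of a 3-unit system; factors = 3 biases, then couplings (0,1), (0,2), (1,2).
x = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 1.0], [1.0, 1.0, 1.0]])
factors = np.array([1.0, 2.0, 3.0, 0.5, -1.0, 0.25])
log_p = unnormalized_log_prob(factors, x)  # one value per state
```

The first call traces and compiles the function for the given input shapes; later calls with the same shapes reuse the compiled version.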
JaxEnt is a research project under active development (as is JAX itself). Expect NotImplementedErrors, and possibly breaking API changes in the future, as JaxEnt gradually supports more use cases. Contributions, feature requests, additions, corrections and suggestions are welcome.
Installation is simple:
git clone https://github.com/adamhaber/jaxent.git
cd jaxent
pip install .
To make sure everything works as planned, run:
cd jaxent
pytest
Maximum entropy distributions over binary variables are very common in a wide variety of fields and applications. Examples include:
- Pairwise, K-Ising and Random Projection models in neuroscience
- Exponential Random Graph Models (ERGMs) in network science
- Ising Model in statistical physics
- Restricted Boltzmann Machines in machine learning
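All of these share the same general form. As a reminder (standard textbook notation, not taken from JaxEnt's code; sign conventions and the choice of {0,1} versus {-1,+1} states vary between fields), the maximum entropy distribution subject to constraints on the expectations of functions f_k is:

```latex
p(x) = \frac{1}{Z(\lambda)} \exp\Big( \sum_{k} \lambda_k f_k(x) \Big),
\qquad
Z(\lambda) = \sum_{x \in \{0,1\}^N} \exp\Big( \sum_{k} \lambda_k f_k(x) \Big),
\qquad
\mathbb{E}_{p}[f_k(x)] = \langle f_k \rangle_{\text{data}}
```

For a pairwise (Ising-like) model, the f_k are the single-unit terms x_i and the products x_i x_j for i < j, so there are N + N(N-1)/2 parameters in total.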
Here's an example of generating fake data from one Ising model, and fitting a different model to the same data:
import jaxent
import numpy as onp
import jax.numpy as np
import matplotlib.pyplot as plt
import jax
N = 15
n_data_points = 10000
# create an all-to-all-connectivity Ising model with random biases and couplings
m = jaxent.Ising(N)
m.factors = np.array(onp.concatenate([onp.random.normal(3,1,N),onp.random.normal(-0.1,0.05,N*(N-1)//2)]))
# sample from the model
emp_data = m.sample(jax.random.PRNGKey(0),n_data_points)
# create a new model and train it using the data generated from the first model
m2 = jaxent.Ising(N)
m2.train(emp_data)
The marginals of m2 are all within the (normalized) error bars of the original data.
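One way to run such a check yourself is sketched below. It uses only NumPy and does not rely on any JaxEnt function; the helper name is hypothetical, and "normalized" is interpreted here as the difference in marginals divided by the binomial standard error of the empirical marginal, which is an assumption.

```python
import numpy as onp  # same alias as in the example above


def normalized_marginal_errors(emp_data, model_data):
    """Difference between model and empirical marginals, in units of the
    empirical standard error. Both inputs: binary arrays of shape (n_samples, N)."""
    emp = onp.asarray(emp_data, dtype=float)
    model = onp.asarray(model_data, dtype=float)
    n = emp.shape[0]
    emp_mean = emp.mean(axis=0)
    # Binomial standard error of each empirical marginal.
    sem = onp.sqrt(emp_mean * (1.0 - emp_mean) / n)
    # Guard against division by zero for units that are always 0 or always 1.
    sem = onp.maximum(sem, 1e-12)
    return (model.mean(axis=0) - emp_mean) / sem


# Toy data: 4 samples of 2 units each.
emp_toy = onp.array([[1, 0], [1, 1], [0, 0], [1, 0]])
model_toy = onp.array([[1, 1], [1, 0], [1, 1], [0, 0]])
z = normalized_marginal_errors(emp_toy, model_toy)
```

With the models above, you could pass `emp_data` and a fresh set of samples from `m2` (for instance from `m2.sample` with a new PRNG key), assuming the samples come back with shape (n_samples, N); transpose them first if they do not. Values of `z` close to or below 1 in absolute value mean the marginals agree to within one standard error.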
- Further improve performance of sampling and training methods
- Implement adaptive binning in the Wang-Landau algorithm
- Expand the test suite
- Support sparse matrices
- Add notebooks and examples
- Big thanks to Ori Maoz, who wrote the original, excellent MATLAB maxent toolbox for our lab. I plagiarized large parts of his API, design choices, etc. - with permission, of course. :-)