JaxEnt

A JAX-based Python package for maximum entropy modeling of multivariate binary data.

What is JaxEnt?

JaxEnt is a small, lightweight Python package for fitting maximum entropy distributions with arbitrary constraints, sampling from them, and computing various quantities of them. As the name suggests, JaxEnt uses JAX to JIT-compile its functions for CPU, GPU, or TPU. JaxEnt implements several popular maximum entropy models (see below), and extending it to other use cases is straightforward.
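For readers new to the term, a maximum entropy distribution constrained to match the averages of chosen functions takes an exponential form. The block below is a sketch in common textbook notation, not taken from JaxEnt's code: the symbols f_i, λ_i, h_i, J_ij and Z are chosen here for illustration, the variables are assumed to take values in {0, 1}, and the sign in the exponent is an assumption (some toolboxes write exp(-Σ λ_i f_i) instead).

```latex
% Maximize entropy subject to matching the empirical averages of f_1..f_K
\max_{p}\; -\sum_{x \in \{0,1\}^N} p(x)\,\log p(x)
\quad \text{s.t.} \quad
\sum_{x} p(x)\, f_i(x) = \langle f_i \rangle_{\text{data}}, \quad i = 1,\dots,K

% Solution: one Lagrange multiplier ("factor") \lambda_i per constraint
p(x) = \frac{1}{Z(\lambda)} \exp\!\Big( \sum_{i=1}^{K} \lambda_i f_i(x) \Big),
\qquad
Z(\lambda) = \sum_{x \in \{0,1\}^N} \exp\!\Big( \sum_{i=1}^{K} \lambda_i f_i(x) \Big)

% Pairwise (Ising) special case: constraints on means and pairwise products
p(x) = \frac{1}{Z} \exp\!\Big( \sum_{i=1}^{N} h_i x_i + \sum_{i<j} J_{ij}\, x_i x_j \Big)
```

Fitting a model amounts to finding the multipliers for which the model averages match the data averages.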

JaxEnt is a research project under active development (as is JAX itself). Expect the occasional NotImplementedError, and possibly breaking API changes in the future as JaxEnt gradually supports more use cases. Contributions, feature requests, additions, corrections, and suggestions are welcome.

Installation

Installation is simple:

git clone https://github.com/adamhaber/jaxent.git
cd jaxent
pip install .

Testing

To make sure everything works as planned, run:

cd jaxent
pytest

Examples

Maximum entropy distributions over binary variables are very common in a wide variety of fields and applications.

Here's an example of generating fake data from one Ising model, and fitting a different model to the same data:

import jaxent
import numpy as onp
import jax.numpy as np
import matplotlib.pyplot as plt
import jax

N = 15
n_data_points = 10000

# create an all-to-all-connectivity Ising model with random biases and couplings
m = jaxent.Ising(N)
biases = onp.random.normal(3, 1, N)                          # one bias per variable
couplings = onp.random.normal(-0.1, 0.05, N * (N - 1) // 2)  # one coupling per pair
m.factors = np.array(onp.concatenate([biases, couplings]))

# sample from the model
emp_data = m.sample(jax.random.PRNGKey(0),n_data_points)

# create a new model and train it using the data generated from the first model
m2 = jaxent.Ising(N)
m2.train(emp_data)
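The factors vector in the example above has N + N*(N-1)/2 entries: N biases followed by one coupling per pair of variables. As a rough illustration of what such a vector parameterizes, here is a minimal plain-Python sketch that does not call JaxEnt. The helper names `ising_log_weight` and `exact_marginals` are hypothetical, and the pair ordering (0,1), (0,2), ..., the 0/1 encoding, and the positive sign in the exponent are assumptions that may differ from JaxEnt's internals.

```python
from itertools import combinations, product
from math import exp


def ising_log_weight(x, factors):
    """Unnormalized log-probability of a binary state x under a pairwise model.

    Assumed layout (illustrative only): N biases, then N*(N-1)/2 couplings
    ordered as the pairs (0,1), (0,2), ..., (N-2,N-1).
    """
    n = len(x)
    if len(factors) != n + n * (n - 1) // 2:
        raise ValueError("expected N biases followed by N*(N-1)/2 couplings")
    biases, couplings = factors[:n], factors[n:]
    total = sum(h * xi for h, xi in zip(biases, x))
    for J, (i, j) in zip(couplings, combinations(range(n), 2)):
        total += J * x[i] * x[j]
    return total


def exact_marginals(factors, n):
    """Brute-force P(x_i = 1) by enumerating all 2**n states (small n only)."""
    Z = 0.0
    weighted = [0.0] * n
    for x in product((0, 1), repeat=n):
        w = exp(ising_log_weight(x, factors))
        Z += w
        for i, xi in enumerate(x):
            weighted[i] += w * xi
    return [v / Z for v in weighted]


# Three variables: 3 biases followed by 3 couplings (values chosen arbitrarily).
print(exact_marginals([0.5, -0.5, 0.0, 0.2, -0.2, 0.1], 3))
```

Exhaustive enumeration only works for a handful of variables, which is why sampling-based methods matter at realistic sizes.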

The marginals of m2 are all within the (normalized) error bars of the original data:

readme figure
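A check of this kind can be sketched in plain Python without any JaxEnt calls. The helpers `marginals_with_stderr` and `within_errorbars` and the `n_sigma` parameter are hypothetical, and the binomial standard error used here is only one plausible reading of "error bars"; the figure may have been produced differently.

```python
import random
from math import sqrt


def marginals_with_stderr(samples):
    """Per-variable mean and binomial standard error for rows of 0/1 values."""
    n = len(samples)
    n_vars = len(samples[0])
    means = [sum(row[i] for row in samples) / n for i in range(n_vars)]
    stderrs = [sqrt(p * (1.0 - p) / n) for p in means]
    return means, stderrs


def within_errorbars(model_marginals, means, stderrs, n_sigma=2.0):
    """Per variable: is the model marginal within n_sigma standard errors?"""
    return [
        abs(q - p) <= n_sigma * s
        for q, p, s in zip(model_marginals, means, stderrs)
    ]


# Synthetic stand-in for emp_data: three independent binary variables.
rng = random.Random(0)
true_p = [0.1, 0.5, 0.9]
data = [[int(rng.random() < p) for p in true_p] for _ in range(10000)]

means, stderrs = marginals_with_stderr(data)
print(within_errorbars(true_p, means, stderrs, n_sigma=4.0))
```

In practice the model marginals would come from samples drawn from the trained model rather than from known probabilities.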

Future Work

  • Further improve performance of sampling and training methods
  • Implement adaptive binning in Wang-Landau algorithm
  • Expand the test suite
  • Support sparse matrices
  • Add notebooks and examples

Thanks

  • Big thanks to Ori Maoz, who wrote the original, excellent MATLAB maxent toolbox for our lab. I plagiarized large parts of his API, design choices, etc., with permission of course. :-)