Add support for JAX batching via vmap #11

Open
dfm opened this issue Oct 29, 2020 · 2 comments
Labels
enhancement New feature or request

Comments

@dfm
Member

dfm commented Oct 29, 2020

This will require updating the backend to iterate over the batch dimension, but that shouldn't be too hard. Then we'd need to register a simple batching rule.

One open question is how batching should work with the terms interface.

https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html#Batching
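The plan above (have the backend loop over the batch dimension, then register a batching rule) can be sketched with a toy JAX primitive. Everything here is hypothetical: `solve_lower_p` and `_solve_lower_impl` stand in for a celerite2 op, and the "backend" is a NumPy loop that only accepts a single 1-D right-hand side. It does forward substitution for a unit lower-bidiagonal system with -1 on the subdiagonal, which amounts to a cumulative sum. The rule follows the pattern in the linked JAX notes: move the batch axis to the front, call the primitive once per batch element, and stack the results. This sketch only covers eager `jax.vmap`; using it under `jit` would additionally need a lowering rule and a tracer-compatible implementation.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.interpreters import batching

try:  # newer JAX releases expose Primitive here
    from jax.extend.core import Primitive
except ImportError:  # older releases
    from jax.core import Primitive

# Hypothetical primitive standing in for a celerite2 solver op.
solve_lower_p = Primitive("toy_solve_lower")


def _solve_lower_impl(b):
    # Stand-in "backend": only handles a single 1-D right-hand side.
    b = np.asarray(b)
    if b.ndim != 1:
        raise ValueError("backend only accepts 1-D input")
    y = np.empty_like(b)
    acc = 0.0
    for i in range(b.shape[0]):
        # Forward substitution for y[i] - y[i-1] = b[i].
        acc = acc + b[i]
        y[i] = acc
    return y


def _solve_lower_batch(batched_args, batch_dims):
    # Batching rule: iterate over the batch dimension, one backend call each.
    (b,), (bdim,) = batched_args, batch_dims
    b = jnp.moveaxis(b, bdim, 0)  # put the batch axis first
    outs = [solve_lower_p.bind(b[i]) for i in range(b.shape[0])]
    return jnp.stack(outs), 0  # output is batched along axis 0


solve_lower_p.def_impl(_solve_lower_impl)
solve_lower_p.def_abstract_eval(lambda b: b)  # same shape/dtype as the input
batching.primitive_batchers[solve_lower_p] = _solve_lower_batch


def solve_lower(b):
    return solve_lower_p.bind(b)


x = jnp.arange(6.0).reshape(2, 3)
print(jax.vmap(solve_lower)(x))  # each row is solved independently
```

Looping in the rule is the simplest option and matches "iterate over the batch dimension"; a faster backend could instead accept the whole stacked batch in one call.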

dfm added the enhancement (New feature or request) label Oct 29, 2020
@bmorris3

Hi @dfm! I'm starting to work with the jax backend, and I'm hitting an error:

NotImplementedError: Batching rule for 'celerite2_solve_lower_jvp' not implemented

Is that what you're referring to in this issue? Happy to contribute if you can give me some pointers on how to get started.

@bmorris3

Ah, I see: this was because I had left forward_mode_differentiation=True in my NUTS(...) call. Sorry for the noise!
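The error above appeared while running NUTS with `forward_mode_differentiation=True`. A likely explanation (an assumption, not confirmed in this thread) is that this mode computes gradients with forward-mode autodiff in the style of `jax.jacfwd`, which vmaps a JVP over a batch of tangent vectors. JVP primitives such as `celerite2_solve_lower_jvp` therefore end up under `vmap` and need batching rules of their own. The sketch below rebuilds a forward-mode Jacobian by hand to show where that `vmap` comes from; `f` and `jacfwd_by_hand` are illustrative names only.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Any differentiable function works; this one is only illustrative.
    return jnp.cumsum(jnp.sin(x))


def jacfwd_by_hand(fun, x):
    # Forward-mode Jacobian: push each standard basis tangent through jvp.
    def pushforward(v):
        return jax.jvp(fun, (x,), (v,))[1]  # J @ v

    basis = jnp.eye(x.shape[0], dtype=x.dtype)
    # vmap batches the JVP over all tangents at once, so every primitive
    # traced inside the JVP must have a batching rule.
    cols = jax.vmap(pushforward)(basis)  # row j holds J @ e_j
    return cols.T


x = jnp.array([0.1, 0.2, 0.3])
print(jnp.allclose(jax.jacfwd(f)(x), jacfwd_by_hand(f, x)))
```

Reverse-mode gradients (`jax.grad`) do not need this batched JVP, which is consistent with the error going away once forward-mode differentiation was turned off.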


2 participants