Create a Jax version of the ESN #12

Closed
2 tasks done
eliseoe opened this issue Nov 28, 2023 · 1 comment
Labels
enhancement New feature or request JAX

Comments

@eliseoe
Collaborator

eliseoe commented Nov 28, 2023

Issue:

I think it is worth exploring a JAX version of our current implementation, so that we can use JAX's optimizations (such as JIT compilation) to speed up ESN training. In my preliminary tests, the jitted version of the step function already speeds up the computation with a reservoir size of 100. I expect the benefit to grow for larger reservoirs.
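For illustration, a jitted reservoir step could look like the minimal sketch below. It assumes a leaky-integrator ESN update, x(t+1) = (1 - alpha) * x(t) + alpha * tanh(W_in u(t+1) + W x(t)); the function name esn_step, the leak rate alpha, and the weight shapes are hypothetical and not taken from this project's code.

```python
import jax
import jax.numpy as jnp


@jax.jit
def esn_step(x, u, W, W_in, alpha):
    # Hypothetical leaky-integrator update: blend the old state with the new tanh activation.
    # x: (N,) state, u: (D,) input, W: (N, N) reservoir, W_in: (N, D) input weights.
    pre_activation = W_in @ u + W @ x
    return (1.0 - alpha) * x + alpha * jnp.tanh(pre_activation)


# Illustrative sizes: reservoir of 100 units, 1-D input.
N, D = 100, 1
k_w, k_in = jax.random.split(jax.random.PRNGKey(0))
# Std 0.09 gives a spectral radius of roughly 0.9 for N = 100 (circular-law estimate, not tuned).
W = 0.09 * jax.random.normal(k_w, (N, N))
W_in = jax.random.uniform(k_in, (N, D), minval=-1.0, maxval=1.0)

x = jnp.zeros(N)
u = jnp.array([0.5])
x = esn_step(x, u, W, W_in, 0.3)  # first call traces and compiles
x = esn_step(x, u, W, W_in, 0.3)  # later calls with the same shapes reuse the compiled code
```

The first call pays the tracing and compilation cost, so any timing comparison against the NumPy version should exclude it.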

Tasks:

  • Implement a JAX version of the standard ESN
  • Adapt the validation strategy to JAX
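For the first task, one common JAX pattern is to run the reservoir over a whole input sequence with jax.lax.scan, so the entire loop compiles as a single program instead of dispatching one jitted step per time step. This is a minimal sketch, assuming the same leaky-integrator update as above and a washout period; init_reservoir, collect_states, and all parameter values are hypothetical.

```python
import jax
import jax.numpy as jnp
import numpy as np


def init_reservoir(key, n_units, n_inputs, spectral_radius=0.9, input_scale=1.0):
    """Hypothetical initializer: random W rescaled to a target spectral radius."""
    k_w, k_in = jax.random.split(key)
    W = jax.random.uniform(k_w, (n_units, n_units), minval=-0.5, maxval=0.5)
    # Eigenvalues are computed once on the host with NumPy, outside any jitted code.
    rho = float(np.max(np.abs(np.linalg.eigvals(np.asarray(W)))))
    W = W * (spectral_radius / rho)
    W_in = input_scale * jax.random.uniform(k_in, (n_units, n_inputs), minval=-1.0, maxval=1.0)
    return W, W_in


@jax.jit
def collect_states(W, W_in, inputs, alpha):
    """Run the reservoir over inputs of shape (T, D) and return states of shape (T, N)."""
    def step(x, u):
        x_new = (1.0 - alpha) * x + alpha * jnp.tanh(W_in @ u + W @ x)
        return x_new, x_new  # (carry for the next step, output stacked by scan)

    x0 = jnp.zeros(W.shape[0])
    _, states = jax.lax.scan(step, x0, inputs)
    return states


W, W_in = init_reservoir(jax.random.PRNGKey(0), n_units=100, n_inputs=1)
t = jnp.linspace(0.0, 20.0, 500)
inputs = jnp.sin(t)[:, None]          # shape (500, 1)
states = collect_states(W, W_in, inputs, 0.3)
washout = 50
states = states[washout:]             # drop the initial transient, shape (450, 100)
```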

Additional Context:

Specifically, I expect that combining SciPy's ridge regression with JAX will give us the fastest version. Porting the standard ESN to JAX is mostly straightforward, but validation.py needs more adjustment.
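One possible reading of this mix: JAX harvests the reservoir states, and SciPy solves the ridge-regression readout on the host. The sketch assumes the standard closed-form ridge solution W_out = (X^T X + lambda I)^-1 X^T Y; ridge_readout and its parameters are hypothetical names, and the random matrix only stands in for real reservoir states.

```python
import jax
import jax.numpy as jnp
import numpy as np
from scipy import linalg


def ridge_readout(states, targets, ridge=1e-6):
    """Solve (X^T X + ridge * I) W_out = X^T Y for W_out with SciPy.

    states: (T, N) reservoir states (JAX or NumPy array)
    targets: (T, K) desired outputs
    Returns W_out with shape (N, K).
    """
    # np.asarray copies the arrays from device memory to host memory.
    X = np.asarray(states, dtype=np.float64)
    Y = np.asarray(targets, dtype=np.float64)
    A = X.T @ X + ridge * np.eye(X.shape[1])
    B = X.T @ Y
    # For ridge > 0, A is symmetric positive definite, so a Cholesky-based solve applies.
    return linalg.solve(A, B, assume_a="pos")


key = jax.random.PRNGKey(1)
states = jax.random.normal(key, (400, 20))        # stand-in for harvested reservoir states
true_W = jnp.arange(20.0).reshape(20, 1) / 20.0
targets = states @ true_W                          # exactly linear targets for a sanity check
W_out = ridge_readout(states, targets, ridge=1e-8)
predictions = np.asarray(states) @ W_out
```

Keeping the solve in SciPy avoids holding the N x N normal-equation matrix on the GPU, at the cost of one device-to-host copy of the states.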

@eliseoe eliseoe added enhancement New feature or request JAX labels Nov 28, 2023
@eliseoe eliseoe self-assigned this Nov 28, 2023
@eliseoe
Collaborator Author

eliseoe commented Jul 30, 2024

This task has been preliminarily completed: a first JAX ESN version is on the main branch, with the validation adapted.

A memory issue still exists (see the issue "Resolve gpu memory issue with JAX esn").

Closing this issue.

@eliseoe eliseoe closed this as completed Jul 30, 2024