Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

POC - Jax implementation of AndersonCD solver #155

Draft
wants to merge 15 commits into
base: main
Choose a base branch
from

Conversation

Badr-MOUFAD
Copy link
Collaborator

@Badr-MOUFAD Badr-MOUFAD commented Apr 27, 2023

Follow up of #149

This implements AndersonCD solver using Jax-GPU. it proceeds as follows:

  • CD solver using Jax
  • Working sets
  • Anderson acceleration
  • use autodiff
  • benchmarks against CPU AndersonCD

@Badr-MOUFAD Badr-MOUFAD marked this pull request as draft April 27, 2023 08:47
@Badr-MOUFAD
Copy link
Collaborator Author

Jax triggers another jit-compilation of functions whenever the function arguments change shape.
I open an issue on google/jax and it happens to be an inherent functioning of the xla compiler.

This is a limiting factor with the current design as the heavy-costly functions, gradient/subdiff_dist, and cd_epoch, have inputs, namely grad_ws, ws, that change shape along the iterations. Therefore most of the time is wasted on recompiling functions.

To bypass that, I'm thinking of tweaking the design to freeze the arrays' shapes across iterations and hence avoid the recompilation. I'm open to other suggestions.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant