
[WIP] [ENH] add faster jvp computation for lasso type problems #17

Open
wants to merge 27 commits into main

Conversation


@QB3 QB3 commented Aug 26, 2021

Algorithmically, implicit differentiation consists of two steps:

  • 1 compute the solution of the optimization problem;
  • 2 solve a linear system.

This linear system can be large and expensive to solve.
When one differentiates the solution of a sparse optimization problem (such as the Lasso), it is possible to reduce the size of the linear system to solve (http://proceedings.mlr.press/v119/bertrand20a/bertrand20a.pdf, https://arxiv.org/pdf/2105.01637.pdf).
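For intuition, here is a minimal NumPy sketch (not the jaxopt API; the function name and setup are mine) of why sparsity shrinks the system: under support stability, the Jacobian of the Lasso solution with respect to the regularization strength is zero outside the support S, so the implicit-function linear system reduces from p × p to |S| × |S|.

```python
import numpy as np

# Hedged sketch. For the Lasso
#   min_w 0.5/n * ||X w - y||^2 + lam * ||w||_1,
# differentiating the optimality condition on the support S = {j : w*_j != 0}
# gives (X_S^T X_S / n) dw_S/dlam = -sign(w*_S), an |S| x |S| system
# instead of a p x p one.

def lasso_jacobian_wrt_lam(X, w_star, n_samples):
    support = np.flatnonzero(w_star)          # S, typically |S| << p
    X_S = X[:, support]
    rhs = -np.sign(w_star[support])
    # Solve the small |S| x |S| system on the support only.
    J_S = np.linalg.solve(X_S.T @ X_S / n_samples, rhs)
    J = np.zeros_like(w_star)                 # Jacobian is zero off-support
    J[support] = J_S
    return J
```

With n_features = 1000 and a support of size around 10, the system shrinks dramatically, which is the speedup this PR hopes to capture.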

The goal of this PR is to implement such an acceleration for sparse optimization problems.

To this end, I implemented a sparse_root_vjp function (https://github.com/QB3/jaxopt/blob/c9b0daea3f90dec392fbcb73097eb88d39010aa2/jaxopt/_src/implicit_diff.py#L125) in which I solve a smaller linear system (https://github.com/QB3/jaxopt/blob/c9b0daea3f90dec392fbcb73097eb88d39010aa2/jaxopt/_src/implicit_diff.py#L169).
The correctness of the implementation is tested here https://github.com/QB3/jaxopt/blob/c9b0daea3f90dec392fbcb73097eb88d39010aa2/tests/implicit_diff_test.py#L88.
Note that this implementation is very crude and not at all general; the goal is to see whether we observe any speedups.

To check the speed of the implementation I created a sparse_vjp.py file in the benchmarks directory (https://github.com/QB3/jaxopt/blob/add_sparse_vjp/benchmarks/sparse_vjp.py).
I differentiate the solution of a Lasso with (n_samples=10, n_features=1000), lam = lam_max / 2.
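For reference, here is a small sketch (function name is mine) of the lam_max the benchmark setting refers to: the smallest regularization strength for which the Lasso solution is exactly zero.

```python
import numpy as np

# Hedged sketch: for the Lasso objective
#   0.5/n * ||X w - y||^2 + lam * ||w||_1,
# the solution is identically zero as soon as lam >= ||X^T y||_inf / n.
# The benchmark then uses lam = lam_max / 2.

def lam_max(X, y):
    n_samples = X.shape[0]
    return np.max(np.abs(X.T @ y)) / n_samples
```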

I have the following results:
Time taken to solve the Lasso optimization problem: 0.008
Time taken to compute the Jacobian: 2.168
Time taken to compute the Jacobian with the sparse implementation: 2.168
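One thing worth ruling out before interpreting such numbers: JAX dispatches work asynchronously and compiles on the first call, so naive wall-clock timing can measure compilation or dispatch rather than the actual solve. A minimal timing sketch (helper name is mine):

```python
import time

# Hedged sketch: warm up once to trigger jit compilation, then block on the
# result before reading the clock. block_until_ready() exists on JAX arrays;
# the getattr fallback makes the helper a no-op for plain Python/NumPy values.

def timeit(fn, *args, n_repeats=3):
    fn(*args)  # warm-up: first call pays compilation cost
    start = time.perf_counter()
    for _ in range(n_repeats):
        result = fn(*args)
        getattr(result, "block_until_ready", lambda: result)()
    return (time.perf_counter() - start) / n_repeats
```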

This benchmark tells us 2 things:

  • there is room for improvement when computing the Jacobian (it takes roughly 270× longer than solving the optimization problem itself);
  • the sparse implementation does not provide any speedup.

I do not understand why we do not observe any speedups; does anyone have a lead?

@google-cla

google-cla bot commented Aug 26, 2021

Thanks for your pull request. It looks like this may be your first contribution to a Google open source project (if not, look below for help). Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

📝 Please visit https://cla.developers.google.com/ to sign.

Once you've signed (or fixed any issues), please reply here with @googlebot I signed it! and we'll verify it.


What to do if you already signed the CLA

Individual signers
Corporate signers

ℹ️ Googlers: Go here for more info.

@google-cla google-cla bot added the cla: no label Aug 26, 2021
@google-cla google-cla bot added cla: yes and removed cla: no labels Aug 26, 2021
@mblondel
Collaborator

@QB3 updated the pull request description with the current state of this pull request. Currently, he doesn't observe any speedups despite restricting the linear system to the support.

@shoyer @froystig @fabianp If you have any idea what could make things faster, let us know.

@shoyer
Member

shoyer commented Aug 27, 2021

> @QB3 updated the pull request description with the current state of this pull request. Currently, he doesn't observe any speedups despite restricting the linear system to the support.

I believe the number of iterations required for CG to converge typically depends on the condition number of the linear operator, rather than the size of the system. It's possible that restricting the system to the support does not change that. More generally, it seems like a good idea to collect some metrics (e.g., via host_callback) about the rate of convergence.
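A minimal sketch of such a metric on the NumPy/SciPy side (helper name is mine; jaxopt's own solver is not used here), counting CG iterations via SciPy's callback so the full and support-restricted systems can be compared:

```python
import numpy as np
from scipy.sparse import linalg as splinalg

# Hedged sketch: SciPy's CG invokes the callback once per iteration, so a
# mutable counter records how many iterations the solve took. If the
# iteration count barely changes after restricting to the support, the
# condition number (not the size) is the bottleneck.

def cg_with_iteration_count(A, b):
    n_iters = [0]

    def callback(xk):
        n_iters[0] += 1

    x, info = splinalg.cg(A, b, callback=callback)
    return x, n_iters[0]
```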

I would also test with a direct solver such as jax.scipy.linalg.solve (on small systems), whose cost depends more directly on the size of the system. You might also try calculating the condition number of the linear operators, e.g., via the approximate eigenvalue calculations in scipy.sparse.linalg.
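A sketch of that condition-number estimate (helper name is mine; assumes a symmetric positive definite operator given as a matvec):

```python
import numpy as np
from scipy.sparse import linalg as splinalg

# Hedged sketch: wrap the matvec in a LinearOperator and use ARPACK (eigsh)
# to approximate the largest and smallest eigenvalues; their ratio is the
# condition number governing CG's convergence rate.

def estimate_condition_number(matvec, n):
    A = splinalg.LinearOperator((n, n), matvec=matvec)
    lam_max = splinalg.eigsh(A, k=1, which='LA', return_eigenvectors=False)[0]
    lam_min = splinalg.eigsh(A, k=1, which='SA', return_eigenvectors=False)[0]
    return lam_max / lam_min
```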

@froystig
Member

Might it also be useful to profile (tutorial, docs) for any lower-level surprises?

@mblondel
Collaborator

@shoyer In the reference NumPy-based implementation of our paper, which also uses CG, @QB3 apparently observed speedups by restricting to the support (please confirm, @QB3).

@QB3
Author

QB3 commented Aug 30, 2021

Thanks a lot for your answers! I tried a larger example (n_samples=10, n_features=10_000) and observed some nice speedups:

Time taken to solve the Lasso optimization problem: 0.015
Time taken to compute the Jacobian: 24.673
Time taken to compute the Jacobian with the sparse implementation: 1.435
