[WIP] [ENH] add faster jvp computation for lasso type problems #17
base: main
Conversation
Thanks for your pull request. It looks like this may be your first contribution to a Google open source project (if not, look below for help). Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA). 📝 Please visit https://cla.developers.google.com/ to sign. Once you've signed (or fixed any issues), please reply here with `@googlebot I signed it!` and we'll verify it.

What to do if you already signed the CLA:
- Individual signers
- Corporate signers

ℹ️ Googlers: Go here for more info.
I believe the number of iterations required for CG to converge typically depends on the condition number of the linear operator rather than on the size of the system, so it's possible that restricting the support of the system does not change that. More generally, it seems like it would be a good idea to collect some metrics (e.g., via host_callback) about the rate of convergence. I would also test with direct solvers.
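For illustration, here is a minimal sketch of the kind of instrumentation suggested above. Since `jax.scipy.sparse.linalg.cg` exposes no per-iteration hook, the sketch uses a hand-rolled CG loop; `cg_logged` and its arguments are illustrative names, not part of jaxopt:

```python
# A minimal sketch: conjugate gradient from x0 = 0, printing the residual
# norm at every iteration via host_callback.
import jax
import jax.numpy as jnp
from jax.experimental import host_callback

def cg_logged(matvec, b, maxiter=50):
    def step(carry, _):
        x, r, p = carry
        Ap = matvec(p)
        rr = jnp.vdot(r, r)
        alpha = rr / jnp.vdot(p, Ap)
        x = x + alpha * p
        r_new = r - alpha * Ap
        # Stream the residual norm back to the host from traced code; reusing
        # the returned value below keeps the tap from being optimized away.
        r_norm = host_callback.id_print(jnp.linalg.norm(r_new),
                                        what="cg residual")
        p = r_new + (r_norm**2 / rr) * p  # beta = ||r_new||^2 / ||r||^2
        return (x, r_new, p), None

    x0 = jnp.zeros_like(b)
    (x, _, _), _ = jax.lax.scan(step, (x0, b, b), None, length=maxiter)
    return x
```

Comparing these residual curves for the full and support-restricted systems would show directly whether the restriction improves conditioning rather than just the system size.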
@shoyer In the reference NumPy-based implementation of our paper, which also uses CG, @QB3 apparently observed speedups by restricting the support (please confirm, @QB3).
Thanks a lot for your answers! I tried a larger example and observed some nice speedups:

```
Time taken to solve the Lasso optimization problem 0.015
```
Algorithmically, implicit differentiation is composed of 2 steps:

1. solving the optimization problem;
2. solving a linear system derived from the optimality conditions at the solution.

This linear system can be large and expensive to solve.
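For concreteness, here is a rough sketch of what step 2 amounts to in plain JAX (simplified, not the actual jaxopt internals; `F`, `w_star`, `theta`, and `v` are illustrative names):

```python
# For a solution w* defined by the root condition F(w, theta) = 0, the VJP
# w.r.t. theta requires solving (dF/dw)^T u = v, a system whose size scales
# with the number of parameters in w.
import jax
import jax.numpy as jnp

def root_vjp_sketch(F, w_star, theta, v):
    # Linear map u -> (dF/dw)^T u at the solution, built with JAX autodiff.
    _, vjp_w = jax.vjp(lambda w: F(w, theta), w_star)
    matvec = lambda u: vjp_w(u)[0]
    # The expensive step: solve the linear system iteratively with CG (which
    # assumes the operator is symmetric positive definite, as when F is the
    # gradient of a smooth strongly convex objective).
    u, _ = jax.scipy.sparse.linalg.cg(matvec, v)
    # Chain rule through theta: the VJP of the solution map is -(dF/dtheta)^T u.
    _, vjp_theta = jax.vjp(lambda t: F(w_star, t), theta)
    return jax.tree_util.tree_map(jnp.negative, vjp_theta(u)[0])
```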
When one differentiates the solution of a sparse optimization problem (such as the Lasso), it is possible to reduce the size of the linear system to solve (http://proceedings.mlr.press/v119/bertrand20a/bertrand20a.pdf, https://arxiv.org/pdf/2105.01637.pdf). The goal of this PR is to implement such an acceleration for sparse optimization problems.
To this end, I implemented a `sparse_root_vjp` function (https://github.com/QB3/jaxopt/blob/c9b0daea3f90dec392fbcb73097eb88d39010aa2/jaxopt/_src/implicit_diff.py#L125) where I solve a smaller linear system (https://github.com/QB3/jaxopt/blob/c9b0daea3f90dec392fbcb73097eb88d39010aa2/jaxopt/_src/implicit_diff.py#L169). The correctness of the implementation is tested here: https://github.com/QB3/jaxopt/blob/c9b0daea3f90dec392fbcb73097eb88d39010aa2/tests/implicit_diff_test.py#L88
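The idea of the smaller system can be illustrated as follows (a minimal sketch, not the actual `sparse_root_vjp` code; `solve_on_support` is an invented helper):

```python
# Because the Lasso solution is sparse, the implicit-differentiation system
# can be restricted to the active set S = {j : w*_j != 0}, replacing an
# n_features x n_features solve by an |S| x |S| one.
import jax.numpy as jnp

def solve_on_support(X, w_star, v):
    support = jnp.flatnonzero(w_star)   # dynamic shape: only works outside jit
    X_S = X[:, support]                 # design matrix restricted to S
    hess_S = X_S.T @ X_S                # the |S| x |S| block of X^T X
    u_S = jnp.linalg.solve(hess_S, v[support])
    # Coordinates outside the support stay exactly zero.
    return jnp.zeros_like(w_star).at[support].set(u_S)
```

The dynamic shape of `support` is one reason such a restriction is awkward to express under `jit`.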
Note that this implementation is very crude and not general at all; the goal is simply to see whether we observe any speedups.
To check the speed of the implementation, I created a `sparse_vjp.py` file in the `benchmarks` directory (https://github.com/QB3/jaxopt/blob/add_sparse_vjp/benchmarks/sparse_vjp.py). I differentiate the solution of a Lasso with `(n_samples=10, n_features=1000)` and `lam = lam_max / 2`. I get the following results:
```
Time taken to solve the Lasso optimization problem 0.008
Time taken to compute the Jacobian 2.168
Time taken to compute the Jacobian with the sparse implementation 2.168
```
This benchmark tells us 2 things: computing the Jacobian (2.168 s) is far more expensive than solving the Lasso itself (0.008 s), and the sparse implementation currently brings no speedup at all (an identical 2.168 s).
I do not understand why we do not observe any speedups. Does anyone have a lead?
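One generic thing worth ruling out in such timings (not verified against the linked benchmark script): JAX dispatches asynchronously and compiles on first call, so a naive timer can measure compilation or stop before the computation finishes. A minimal sketch of a pattern that excludes both; `timeit_jax` is an illustrative helper, not jaxopt API:

```python
# Generic JAX timing pattern: warm up once (to exclude compilation), then time
# a second call and block on the result (to exclude asynchronous dispatch).
import time
import jax

def timeit_jax(fn, *args):
    block = lambda out: jax.tree_util.tree_map(
        lambda x: x.block_until_ready(), out)
    block(fn(*args))                     # warm-up call: triggers compilation
    start = time.perf_counter()
    block(fn(*args))                     # timed call, fully materialized
    return time.perf_counter() - start
```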