Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] [ENH] add faster jvp computation for lasso type problems #17

Open
wants to merge 27 commits into
base: main
Choose a base branch
from
Open
Changes from 1 commit
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
6ef6ed0
[WIP] first draft for sparse vjp
QB3 Aug 26, 2021
64034b8
[ci skip] fixed call jax.vjp, tests still do not pass
QB3 Aug 26, 2021
71b8b2c
[ci skip] made test implicit diff for sparse_jvp pass
QB3 Aug 26, 2021
4e8a791
[ci skip] added test lasso, currently fails
QB3 Aug 26, 2021
a9883d6
[ci skip] added test lasso without sparsity
QB3 Aug 26, 2021
7733f9a
Trigger google-cla
QB3 Aug 26, 2021
cbe7bac
[ci skip] made test pass for lasso without sparse computation
QB3 Aug 26, 2021
478d871
[ci skip]@ new try for sparse computation, still fails
QB3 Aug 26, 2021
0346b97
[ci skip] made sparse jvp work, remains to see how much we win
QB3 Aug 26, 2021
8a3c128
[ci skip] added little bench for sparse vjp, sol is better but runnin…
QB3 Aug 26, 2021
c9b0dae
[ci skip] try implemetation with hardcoded support
QB3 Aug 27, 2021
b698e02
[ci skip] take larger number of features, see speed ups
QB3 Aug 30, 2021
0c230ce
add make_restricted_optimality_fun to sparse_vjp
QB3 Aug 30, 2021
933f287
[ci skip] simplified + rearanged args in tests
QB3 Aug 30, 2021
1ddeeb9
[ci skip] CLN
QB3 Aug 30, 2021
3aae541
[ci skip] made test_custom_root_lasso
QB3 Aug 30, 2021
3a3ef0b
[ci skip] added sparse custom root + tests
QB3 Aug 30, 2021
4669b96
[ci skip] adapted example lasso, toward a sparse implementation
QB3 Aug 30, 2021
d110ff1
[ci skip] added benchmark sparse custom root
QB3 Aug 31, 2021
d15ae55
[ci skip] added sparse_custom_root to implicit diff example
QB3 Aug 31, 2021
77661b1
improved benchmark file
QB3 Sep 12, 2021
a7e5436
jax.numpy.linalg >> onp.linalg.norm
QB3 Sep 12, 2021
a8ee7cc
[ci skip] added back version with hardcoded support
QB3 Sep 12, 2021
d7e5cc1
[ciskip] updated benchmark
QB3 Sep 13, 2021
1c59af7
[ciskip] added sparse_custom root with other implem, currently fails
QB3 Sep 13, 2021
91348aa
[ciskip] X.T@(X @ params - y) >> grad(obj.square)
QB3 Sep 13, 2021
c64c596
[ci skip] made test custom root work
QB3 Sep 14, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion tests/implicit_diff_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,12 @@
from absl.testing import absltest
from absl.testing import parameterized

import numpy as np
import jax
from jax import test_util as jtu
import jax.numpy as jnp

from jaxopt import prox
from jaxopt import implicit_diff as idf
from jaxopt._src import test_util

Expand Down Expand Up @@ -77,7 +79,12 @@ def test_root_vjp(self):

def test_lasso_root_vjp(self):
X, y = datasets.make_regression(n_samples=10, n_features=3, random_state=0)
optimality_fun = jax.grad(lasso_objective)
L = jax.numpy.linalg.norm(X, ord=2) ** 2

def optimality_fun(params, lam, X, y):
return prox.prox_lasso(
params - X.T @ (X @ params - y) / L, lam * len(y) / L) - params
QB3 marked this conversation as resolved.
Show resolved Hide resolved

lam_max = jnp.max(jnp.abs(X.T @ y)) / len(y)
lam = lam_max / 2
sol = test_util.lasso_skl(X, y, lam)
Expand Down