[WIP] [ENH] add faster jvp computation for lasso type problems #17

Open
wants to merge 27 commits into main
Changes from 10 commits

Commits (27)
6ef6ed0
[WIP] first draft for sparse vjp
QB3 Aug 26, 2021
64034b8
[ci skip] fixed call jax.vjp, tests still do not pass
QB3 Aug 26, 2021
71b8b2c
[ci skip] made test implicit diff for sparse_jvp pass
QB3 Aug 26, 2021
4e8a791
[ci skip] added test lasso, currently fails
QB3 Aug 26, 2021
a9883d6
[ci skip] added test lasso without sparsity
QB3 Aug 26, 2021
7733f9a
Trigger google-cla
QB3 Aug 26, 2021
cbe7bac
[ci skip] made test pass for lasso without sparse computation
QB3 Aug 26, 2021
478d871
[ci skip]@ new try for sparse computation, still fails
QB3 Aug 26, 2021
0346b97
[ci skip] made sparse jvp work, remains to see how much we win
QB3 Aug 26, 2021
8a3c128
[ci skip] added little bench for sparse vjp, sol is better but runnin…
QB3 Aug 26, 2021
c9b0dae
[ci skip] try implemetation with hardcoded support
QB3 Aug 27, 2021
b698e02
[ci skip] take larger number of features, see speed ups
QB3 Aug 30, 2021
0c230ce
add make_restricted_optimality_fun to sparse_vjp
QB3 Aug 30, 2021
933f287
[ci skip] simplified + rearanged args in tests
QB3 Aug 30, 2021
1ddeeb9
[ci skip] CLN
QB3 Aug 30, 2021
3aae541
[ci skip] made test_custom_root_lasso
QB3 Aug 30, 2021
3a3ef0b
[ci skip] added sparse custom root + tests
QB3 Aug 30, 2021
4669b96
[ci skip] adapted example lasso, toward a sparse implementation
QB3 Aug 30, 2021
d110ff1
[ci skip] added benchmark sparse custom root
QB3 Aug 31, 2021
d15ae55
[ci skip] added sparse_custom_root to implicit diff example
QB3 Aug 31, 2021
77661b1
improved benchmark file
QB3 Sep 12, 2021
a7e5436
jax.numpy.linalg >> onp.linalg.norm
QB3 Sep 12, 2021
a8ee7cc
[ci skip] added back version with hardcoded support
QB3 Sep 12, 2021
d7e5cc1
[ciskip] updated benchmark
QB3 Sep 13, 2021
1c59af7
[ciskip] added sparse_custom root with other implem, currently fails
QB3 Sep 13, 2021
91348aa
[ciskip] X.T@(X @ params - y) >> grad(obj.square)
QB3 Sep 13, 2021
c64c596
[ci skip] made test custom root work
QB3 Sep 14, 2021
66 changes: 66 additions & 0 deletions benchmarks/sparse_vjp.py
@@ -0,0 +1,66 @@
import time
import jax

import jax.numpy as jnp

from jaxopt import prox
from jaxopt import implicit_diff as idf
from jaxopt._src import test_util


from sklearn import datasets

X, y = datasets.make_regression(n_samples=10, n_features=1000, random_state=0)

L = jax.numpy.linalg.norm(X, ord=2) ** 2


def optimality_fun(params, lam, X, y):
return prox.prox_lasso(
params - X.T @ (X @ params - y) / L, lam * len(y) / L) - params


def optimality_fun_sparse(params, lam, X, y):
support = params != 0
res = X[:, support].T @ (X[:, support] @ params[support] - y) / L
res = params[support] - res
res = prox.prox_lasso(res, lam * len(y) / L)
res -= params[support]
return res


lam_max = jnp.max(jnp.abs(X.T @ y)) / len(y)
lam = lam_max / 2
t_start = time.time()
sol = test_util.lasso_skl(X, y, lam)
t_optim = time.time() - t_start

vjp = lambda g: idf.root_vjp(optimality_fun=optimality_fun,
sol=sol,
args=(lam, X, y),
cotangent=g)[0] # vjp w.r.t. lam

vjp_sparse = lambda g: idf.sparse_root_vjp(
optimality_fun=optimality_fun_sparse,
sol=sol,
args=(lam, X, y),
cotangent=g)[0] # vjp w.r.t. lam

t_start = time.time()
I = jnp.eye(len(sol))
J = jax.vmap(vjp)(I)
t_jac = time.time() - t_start

t_start = time.time()
I = jnp.eye(len(sol))
J_sparse = jax.vmap(vjp_sparse)(I)
t_jac_sparse = time.time() - t_start

print("Time taken to solve the Lasso optimization problem %.3f" % t_optim)
print("Time taken to compute the Jacobian %.3f" % t_jac)
print("Time taken to compute the Jacobian with the sparse implementation %.3f" % t_jac)


# Computation times are the same, which is very weird to me.
# However, the Jacobian computed the sparse way is much closer to the real
# Jacobian
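
A note on the timings above: JAX dispatches work asynchronously, so wrapping a call in time.time() can stop the clock before the computation has actually finished, and the first call also pays one-time tracing overhead. Below is a minimal sketch of a fairer measurement, reusing the imports and the vjp / vjp_sparse closures from the script above; the time_jacobian helper is purely illustrative and not part of jaxopt.

def time_jacobian(vjp_fun, n_features, repeat=3):
  """Time a full Jacobian computation, forcing JAX to finish the work."""
  I = jnp.eye(n_features)
  # Warm-up call so one-time tracing/dispatch overhead is not counted.
  jax.vmap(vjp_fun)(I).block_until_ready()
  times = []
  for _ in range(repeat):
    t_start = time.time()
    J = jax.vmap(vjp_fun)(I).block_until_ready()
    times.append(time.time() - t_start)
  return J, min(times)

J, t_dense = time_jacobian(vjp, len(sol))
J_sparse, t_sparse = time_jacobian(vjp_sparse, len(sol))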
104 changes: 104 additions & 0 deletions jaxopt/_src/implicit_diff.py
@@ -19,11 +19,13 @@
from typing import Callable
from typing import Tuple

import numpy as np # to be removed, this is for the first draft
import jax

from jaxopt._src import linear_solve
from jaxopt._src.tree_util import tree_scalar_mul
from jaxopt._src.tree_util import tree_sub
from jaxopt._src.tree_util import tree_zeros_like


def root_vjp(optimality_fun: Callable,
@@ -73,6 +75,108 @@ def fun_args(*args):
return vjp_fun_args(u)


def root_vjp(optimality_fun: Callable,
sol: Any,
args: Tuple,
cotangent: Any,
solve: Callable = linear_solve.solve_normal_cg) -> Any:
"""Vector-Jacobian product of a root.

The invariant is ``optimality_fun(sol, *args) == 0``.

Args:
optimality_fun: the optimality function to use.
sol: solution / root (pytree).
args: tuple containing the arguments with respect to which we wish to
differentiate ``sol`` against.
cotangent: vector to left-multiply the Jacobian with
(pytree, same structure as ``sol``).
solve: a linear solver of the form, ``x = solve(matvec, b)``,
where ``matvec(x) = Ax`` and ``Ax=b``.
Returns:
vjps: tuple of the same length as ``len(args)`` containing the vjps w.r.t.
each argument. Each ``vjps[i]`` has the same pytree structure as
``args[i]``.
"""
def fun_sol(sol):
# We close over the arguments.
return optimality_fun(sol, *args)

_, vjp_fun_sol = jax.vjp(fun_sol, sol)

# Compute the multiplication A^T u = (u^T A)^T.
matvec = lambda u: vjp_fun_sol(u)[0]

# The solution of A^T u = v, where
# A = jacobian(optimality_fun, argnums=0)
# v = -cotangent.
v = tree_scalar_mul(-1, cotangent)
u = solve(matvec, v)

def fun_args(*args):
# We close over the solution.
return optimality_fun(sol, *args)

_, vjp_fun_args = jax.vjp(fun_args, *args)

return vjp_fun_args(u)


def sparse_root_vjp(optimality_fun: Callable,
sol: Any,
args: Tuple,
cotangent: Any,
solve: Callable = linear_solve.solve_normal_cg) -> Any:
"""Sparse vector-Jacobian product of a root.

The invariant is ``optimality_fun(sol, *args) == 0``.

Args:
optimality_fun: the optimality function to use (``F`` in the paper).
sol: solution / root (pytree).
args: tuple containing the arguments with respect to which we wish to
differentiate ``sol`` against.
cotangent: vector to left-multiply the Jacobian with
(pytree, same structure as ``sol``).
solve: a linear solver of the form, ``x = solve(matvec, b)``,
where ``matvec(x) = Ax`` and ``Ax=b``.
Returns:
vjps: tuple of the same length as ``len(args)`` containing the vjps w.r.t.
each argument. Each ``vjps[i]`` has the same pytree structure as
``args[i]``.
"""
support = sol != 0  # nonzero coefficients of the solution
restricted_sol = sol[support] # solution restricted to the support

def fun_sol(restricted_sol):
# We close over the arguments.
# Maybe this could be optimized
sol_ = tree_zeros_like(sol)
sol_ = jax.ops.index_update(sol_, support, restricted_sol)
return optimality_fun(sol_, *args)

_, vjp_fun_sol = jax.vjp(fun_sol, restricted_sol)

# Compute the multiplication A^T u = (u^T A)^T restricted to the support.
def restricted_matvec(restricted_v):
return vjp_fun_sol(restricted_v)[0]

# The solution of A^T u = v, where
# A = jacobian(optimality_fun, argnums=0)
# v = -cotangent.
restricted_v = tree_scalar_mul(-1, cotangent[support])
restricted_u = solve(restricted_matvec, restricted_v)

def fun_args(*args):
# We close over the solution.
return optimality_fun(sol, *args)

_, vjp_fun_args = jax.vjp(fun_args, *args)

return vjp_fun_args(restricted_u)


def _jvp_sol(optimality_fun, sol, args, tangent):
"""JVP in the first argument of optimality_fun."""
# We close over the arguments.
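
For reviewers, here is a minimal, self-contained sketch of how the new sparse_root_vjp is meant to be called, mirroring the Lasso test and benchmark added in this PR. The optimality function evaluates the proximal-gradient fixed-point residual only on the support of the solution, so sparse_root_vjp solves the linear system restricted to the nonzero coefficients and then pulls the result back onto the hyperparameters with jax.vjp; this is a sketch under those assumptions, not new API.

import jax
import jax.numpy as jnp
from sklearn import datasets

from jaxopt import implicit_diff as idf
from jaxopt import prox
from jaxopt._src import test_util

X, y = datasets.make_regression(n_samples=10, n_features=3, random_state=0)
L = jnp.linalg.norm(X, ord=2) ** 2


def optimality_fun_sparse(params, lam, X, y):
  # Proximal-gradient fixed-point residual, evaluated only on the support.
  support = params != 0
  res = X[:, support].T @ (X[:, support] @ params[support] - y) / L
  res = prox.prox_lasso(params[support] - res, lam * len(y) / L)
  return res - params[support]


lam_max = jnp.max(jnp.abs(X.T @ y)) / len(y)
lam = lam_max / 2
sol = test_util.lasso_skl(X, y, lam)  # solution computed outside of JAX

vjp = lambda g: idf.sparse_root_vjp(optimality_fun=optimality_fun_sparse,
                                    sol=sol,
                                    args=(lam, X, y),
                                    cotangent=g)[0]  # vjp w.r.t. lam
J = jax.vmap(vjp)(jnp.eye(len(sol)))  # d sol / d lam, one entry per coefficient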
1 change: 1 addition & 0 deletions jaxopt/implicit_diff.py
@@ -16,3 +16,4 @@
from jaxopt._src.implicit_diff import custom_fixed_point
from jaxopt._src.implicit_diff import root_jvp
from jaxopt._src.implicit_diff import root_vjp
from jaxopt._src.implicit_diff import sparse_root_vjp
59 changes: 59 additions & 0 deletions tests/implicit_diff_test.py
@@ -12,15 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from numpy.core.numeric import ones
from absl.testing import absltest
from absl.testing import parameterized

import numpy as np
import jax
from jax import test_util as jtu
import jax.numpy as jnp

from jaxopt import prox
from jaxopt import implicit_diff as idf
from jaxopt._src import test_util
from jaxopt import objective

from sklearn import datasets

@@ -30,6 +34,12 @@ def ridge_objective(params, lam, X, y):
return 0.5 * jnp.mean(residuals ** 2) + 0.5 * lam * jnp.sum(params ** 2)


def lasso_objective(params, lam, X, y):
  residuals = jnp.dot(X, params) - y
  return 0.5 * jnp.mean(residuals ** 2) + lam * jnp.sum(jnp.abs(params))


# def ridge_solver(init_params, lam, X, y):
def ridge_solver(init_params, lam, X, y):
del init_params # not used
@@ -55,6 +65,55 @@ def test_root_vjp(self):
J_num = test_util.ridge_solver_jac(X, y, lam, eps=1e-4)
self.assertArraysAllClose(J, J_num, atol=5e-2)

def test_lasso_root_vjp(self):
X, y = datasets.make_regression(n_samples=10, n_features=3, random_state=0)
L = jax.numpy.linalg.norm(X, ord=2) ** 2

def optimality_fun(params, lam, X, y):
return prox.prox_lasso(
params - X.T @ (X @ params - y) / L, lam * len(y) / L) - params

lam_max = jnp.max(jnp.abs(X.T @ y)) / len(y)
lam = lam_max / 2
sol = test_util.lasso_skl(X, y, lam)
vjp = lambda g: idf.root_vjp(optimality_fun=optimality_fun,
sol=sol,
args=(lam, X, y),
cotangent=g)[0] # vjp w.r.t. lam
I = jnp.eye(len(sol))
J = jax.vmap(vjp)(I)
J_num = test_util.lasso_skl_jac(X, y, lam, eps=1e-4)
self.assertArraysAllClose(J, J_num, atol=5e-2)

def test_lasso_sparse_root_vjp(self):
X, y = datasets.make_regression(n_samples=10, n_features=3, random_state=0)

L = jax.numpy.linalg.norm(X, ord=2) ** 2

def optimality_fun(params, lam, X, y):
support = params != 0
res = X[:, support].T @ (X[:, support] @ params[support] - y) / L
res = params[support] - res
res = prox.prox_lasso(res, lam * len(y) / L)
res -= params[support]
return res

lam_max = jnp.max(jnp.abs(X.T @ y)) / len(y)
lam = lam_max / 2
sol = test_util.lasso_skl(X, y, lam)

# jax.jacobian(optimality_fun)(jnp.ones(X.shape[1]), lam, X, y)
# test the mask in optimality_fun

vjp = lambda g: idf.sparse_root_vjp(optimality_fun=optimality_fun,
sol=sol,
args=(lam, X, y),
cotangent=g)[0] # vjp w.r.t. lam
I = jnp.eye(len(sol))
J = jax.vmap(vjp)(I)
J_num = test_util.lasso_skl_jac(X, y, lam, eps=1e-4)
self.assertArraysAllClose(J, J_num, atol=5e-2)

def test_root_jvp(self):
X, y = datasets.make_regression(n_samples=10, n_features=3, random_state=0)
optimality_fun = jax.grad(ridge_objective)
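
Both new tests compare against test_util.lasso_skl_jac(X, y, lam, eps=1e-4) as the reference Jacobian. For reviewers unfamiliar with that helper, the reference can be understood as a central finite difference of the sklearn Lasso solution with respect to lam; the following is a hypothetical, standalone sketch of that idea, and the real helper in test_util may differ in details such as solver tolerances.

from sklearn.linear_model import Lasso


def lasso_skl_fd(X, y, lam):
  # sklearn's Lasso minimizes (1 / (2 * n_samples)) * ||y - Xw||^2 + alpha * ||w||_1.
  return Lasso(alpha=lam, fit_intercept=False, tol=1e-10).fit(X, y).coef_


def lasso_skl_jac_fd(X, y, lam, eps=1e-4):
  # Central finite difference of the Lasso solution w.r.t. the regularization strength.
  return (lasso_skl_fd(X, y, lam + eps) - lasso_skl_fd(X, y, lam - eps)) / (2 * eps)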