[WIP] [ENH] add faster jvp computation for lasso type problems #17

Status: Open. Wants to merge 27 commits into base: main.

Commits (27):
- 6ef6ed0  [WIP] first draft for sparse vjp (QB3, Aug 26, 2021)
- 64034b8  [ci skip] fixed call jax.vjp, tests still do not pass (QB3, Aug 26, 2021)
- 71b8b2c  [ci skip] made test implicit diff for sparse_jvp pass (QB3, Aug 26, 2021)
- 4e8a791  [ci skip] added test lasso, currently fails (QB3, Aug 26, 2021)
- a9883d6  [ci skip] added test lasso without sparsity (QB3, Aug 26, 2021)
- 7733f9a  Trigger google-cla (QB3, Aug 26, 2021)
- cbe7bac  [ci skip] made test pass for lasso without sparse computation (QB3, Aug 26, 2021)
- 478d871  [ci skip] new try for sparse computation, still fails (QB3, Aug 26, 2021)
- 0346b97  [ci skip] made sparse jvp work, remains to see how much we win (QB3, Aug 26, 2021)
- 8a3c128  [ci skip] added little bench for sparse vjp, sol is better but runnin… (QB3, Aug 26, 2021)
- c9b0dae  [ci skip] try implemetation with hardcoded support (QB3, Aug 27, 2021)
- b698e02  [ci skip] take larger number of features, see speed ups (QB3, Aug 30, 2021)
- 0c230ce  add make_restricted_optimality_fun to sparse_vjp (QB3, Aug 30, 2021)
- 933f287  [ci skip] simplified + rearanged args in tests (QB3, Aug 30, 2021)
- 1ddeeb9  [ci skip] CLN (QB3, Aug 30, 2021)
- 3aae541  [ci skip] made test_custom_root_lasso (QB3, Aug 30, 2021)
- 3a3ef0b  [ci skip] added sparse custom root + tests (QB3, Aug 30, 2021)
- 4669b96  [ci skip] adapted example lasso, toward a sparse implementation (QB3, Aug 30, 2021)
- d110ff1  [ci skip] added benchmark sparse custom root (QB3, Aug 31, 2021)
- d15ae55  [ci skip] added sparse_custom_root to implicit diff example (QB3, Aug 31, 2021)
- 77661b1  improved benchmark file (QB3, Sep 12, 2021)
- a7e5436  jax.numpy.linalg >> onp.linalg.norm (QB3, Sep 12, 2021)
- a8ee7cc  [ci skip] added back version with hardcoded support (QB3, Sep 12, 2021)
- d7e5cc1  [ciskip] updated benchmark (QB3, Sep 13, 2021)
- 1c59af7  [ciskip] added sparse_custom root with other implem, currently fails (QB3, Sep 13, 2021)
- 91348aa  [ciskip] X.T@(X @ params - y) >> grad(obj.square) (QB3, Sep 13, 2021)
- c64c596  [ci skip] made test custom root work (QB3, Sep 14, 2021)
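For context, below is a minimal, self-contained sketch of the idea this PR exploits. It is not the PR's implementation (whose entry points are idf.sparse_root_vjp and implicit_diff.sparse_custom_root, exercised in the benchmarks and example that follow): the Lasso solution is sparse, so the linear system arising in implicit differentiation only needs to involve the support of the solution, which is typically much smaller than the number of features. The names lasso_fixed_point and vjp_wrt_lam_on_support are illustrative only, the dense jnp.linalg.solve stands in for the PR's matrix-free solvers, and the sketch assumes strict complementarity, i.e. the support is stable under small perturbations of lam.

import jax
import jax.numpy as jnp


def lasso_fixed_point(params, X, y, lam, stepsize):
  # One proximal-gradient step on (1 / (2 n)) ||X w - y||^2 + lam ||w||_1;
  # its fixed point is the Lasso solution (for any positive stepsize).
  grad = X.T @ (X @ params - y) / X.shape[0]
  z = params - stepsize * grad
  return jnp.sign(z) * jnp.maximum(jnp.abs(z) - stepsize * lam, 0.0) - params


def vjp_wrt_lam_on_support(sol, X, y, lam, stepsize, cotangent):
  # Implicit function theorem restricted to the support S of sol:
  # with F(w_S, lam) = 0, v^T dw/dlam = -u^T dF/dlam, where (dF/dw_S)^T u = v_S.
  support = jnp.flatnonzero(sol)
  X_s, sol_s = X[:, support], sol[support]
  F = lambda w, lam_: lasso_fixed_point(w, X_s, y, lam_, stepsize)
  A = jax.jacobian(F, argnums=0)(sol_s, lam)  # (|S|, |S|) instead of (p, p)
  B = jax.jacobian(F, argnums=1)(sol_s, lam)  # (|S|,)
  u = jnp.linalg.solve(A.T, cotangent[support])
  return -u @ B

Given a solution sol from any Lasso solver (e.g. test_util.lasso_skl, as in the benchmarks below), vjp_wrt_lam_on_support(sol, X, y, lam, 1.0 / L, cotangent) returns v^T dsol/dlam while only differentiating through an |S|-dimensional system.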
benchmarks/sparse_custom_root.py (new file: 75 additions, 0 deletions)
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import time
import jax
import jax.numpy as jnp

from jaxopt import prox
from jaxopt import implicit_diff as idf
from jaxopt._src import test_util
from jaxopt import objective

from sklearn import datasets


def lasso_objective(params, lam, X, y):
  # (1 / (2 n_samples)) * ||X params - y||^2 + lam * ||params||_1.
  # jnp.mean already includes the 1 / n_samples factor.
  residuals = jnp.dot(X, params) - y
  return 0.5 * jnp.mean(residuals ** 2) + lam * jnp.sum(jnp.abs(params))


def lasso_solver(params, X, y, lam):
  # params (the initialization) is unused: scikit-learn's solver is called directly.
  sol = test_util.lasso_skl(X, y, lam)
  return sol


X, y = datasets.make_regression(
n_samples=10, n_features=10_000, random_state=0)
lam_max = jnp.max(jnp.abs(X.T @ y)) / len(y)
lam = lam_max / 2
L = jax.numpy.linalg.norm(X, ord=2) ** 2


def make_restricted_optimality_fun(support):
  def restricted_optimality_fun(restricted_params, X, y, lam):
    # Suboptimal: ideally restricted_X would be computed once and for all.
    restricted_X = X[:, support]
    return lasso_optimality_fun(restricted_params, restricted_X, y, lam)
  return restricted_optimality_fun


def lasso_optimality_fun(params, X, y, lam):
  # Fixed point of one proximal-gradient step: zero exactly at the Lasso solution.
  n_samples = X.shape[0]
  grad_step = params - jax.grad(objective.least_squares)(params, (X, y)) * n_samples / L
  return prox.prox_lasso(grad_step, lam * len(y) / L) - params


t_start = time.time()
lasso_solver_decorated = idf.custom_root(lasso_optimality_fun)(lasso_solver)
sol = test_util.lasso_skl(X=X, y=y, lam=lam)
J = jax.jacrev(lasso_solver_decorated, argnums=3)(None, X, y, lam)
t_custom = time.time() - t_start


t_start = time.time()
lasso_solver_decorated = idf.sparse_custom_root(
lasso_optimality_fun, make_restricted_optimality_fun)(lasso_solver)
sol = test_util.lasso_skl(X=X, y=y, lam=lam)
J = jax.jacrev(lasso_solver_decorated, argnums=3)(None, X, y, lam)
t_custom_sparse = time.time() - t_start


print("Time taken to compute the Jacobian %.3f" % t_custom)
print("Time taken to compute the Jacobian with the sparse implementation %.3f" % t_custom_sparse)
benchmarks/sparse_vjp.py (new file: 104 additions, 0 deletions)
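"""Benchmark of root_vjp against the sparse variants (sparse_root_vjp and
sparse_root_vjp2) on a high-dimensional Lasso problem."""
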
import time
import jax

import jax.numpy as jnp
import numpy as onp

from jaxopt import prox
from jaxopt import implicit_diff as idf
from jaxopt._src import test_util
from jaxopt import linear_solve
from jaxopt import objective

from sklearn import datasets

X, y = datasets.make_regression(
n_samples=100, n_features=100_000, random_state=0)

L = onp.linalg.norm(X, ord=2) ** 2


def optimality_fun(params, X, y, lam):
  # Proximal-gradient fixed point: zero exactly at the Lasso solution.
  n_samples = X.shape[0]
  return prox.prox_lasso(
      params - jax.grad(objective.least_squares)(params, (X, y)) * n_samples / L,
      lam * len(y) / L) - params


def make_restricted_optimality_fun(support):
  def restricted_optimality_fun(restricted_params, X, y, lam):
    # Suboptimal: ideally restricted_X would be computed once and for all.
    restricted_X = X[:, support]
    return optimality_fun(restricted_params, restricted_X, y, lam)
  return restricted_optimality_fun


lam_max = jnp.max(jnp.abs(X.T @ y)) / len(y)
lam = lam_max / 2
t_start = time.time()
sol = test_util.lasso_skl(X, y, lam)
t_optim = time.time() - t_start

onp.random.seed(0)
rand = onp.random.normal(0, 1, len(sol))
dict_times = {}
dict_grad = {}

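# Baseline: standard root_vjp, with an increasing cap on the number of
# normal-CG iterations; times and gradients are recorded for comparison with
# the sparse implementations below.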
for maxiter in [10, 100, 1000, 2000]:
  def solve(matvec, b):
    return linear_solve.solve_normal_cg(
        matvec, b, None, tol=1e-32, maxiter=maxiter)

  vjp = lambda g: idf.root_vjp(
      optimality_fun=optimality_fun,
      sol=sol,
      args=(X, y, lam),
      cotangent=g,
      solve=solve)[2]  # vjp w.r.t. lam

  t_start = time.time()
  grad = vjp(rand)
  t_jac = time.time() - t_start
  dict_times[maxiter] = t_jac
  dict_grad[maxiter] = grad.copy()


def solve_sparse(matvec, b):
  return linear_solve.solve_cg(
      matvec, b, None, tol=1e-32, maxiter=(sol != 0).sum())


vjp_sparse = lambda g: idf.sparse_root_vjp(
    optimality_fun=optimality_fun,
    make_restricted_optimality_fun=make_restricted_optimality_fun,
    sol=sol,
    args=(X, y, lam),
    cotangent=g,
    solve=solve_sparse)[2]  # vjp w.r.t. lam

vjp_sparse2 = lambda g: idf.sparse_root_vjp2(
    optimality_fun=optimality_fun,
    sol=sol,
    args=(X, y, lam),
    cotangent=g,
    solve=solve_sparse)[2]  # vjp w.r.t. lam

t_start = time.time()
grad_sparse = vjp_sparse(rand)
t_jac_sparse = time.time() - t_start

t_start = time.time()
grad_sparse2 = vjp_sparse2(rand)
t_jac_sparse2 = time.time() - t_start

print("Time taken to solve the Lasso optimization problem %.3f" % t_optim)
for maxiter in dict_times.keys():
print("Time taken to compute the gradient with n= %i iterations %.3f | distance to the sparse gradient %.e" % (
maxiter, dict_times[maxiter], jnp.linalg.norm(dict_grad[maxiter] - grad_sparse) / grad_sparse))
print("Time taken to compute the gradient with the sparse implementation %.3f" % t_jac_sparse)
print("Time taken to compute the gradient with the sparse2 implementation %.3f" % t_jac_sparse2)


# The computation times are about the same, which is surprising.
# However, the Jacobian computed the sparse way is much closer to the true
# Jacobian.
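As a cross-check on that observation, here is a hedged sketch (not part of the PR's code): under strict complementarity the Lasso solution is locally linear in lam on its support, which yields a closed-form reference value for the vjp computed above. It assumes X restricted to the support has full column rank; the variable names below are illustrative.

# From the optimality condition on the support S,
#   X_S^T (X_S w_S - y) + n * lam * sign(w_S) = 0,
# differentiating w.r.t. lam (with S and the signs held fixed) gives
#   d w_S / d lam = -n * (X_S^T X_S)^{-1} sign(w_S).
support_ref = onp.flatnonzero(sol)
X_s = X[:, support_ref]
dw_dlam = -len(y) * onp.linalg.solve(X_s.T @ X_s, onp.sign(sol[support_ref]))
grad_closed_form = rand[support_ref] @ dw_dlam
print("Closed-form reference for the vjp w.r.t. lam: %.2e" % grad_closed_form)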
examples/lasso_implicit_diff_sparse.py (new file: 122 additions, 0 deletions)
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Implicit differentiation of the lasso based on a sparse implementation."""

import time
from absl import app
import jax
import jax.numpy as jnp
import numpy as onp
from jaxopt import implicit_diff
from jaxopt import linear_solve
from jaxopt import OptaxSolver
from jaxopt import prox
from jaxopt import objective
from jaxopt._src import test_util
import optax
from sklearn import datasets
from sklearn import model_selection
from sklearn import preprocessing

# def main(argv):
# del argv

# Prepare data.
# X, y = datasets.load_boston(return_X_y=True)

X, y = datasets.make_regression(
n_samples=30, n_features=10_000, random_state=0)

# X = preprocessing.normalize(X)
# data = (X_tr, X_val, y_tr, y_val)
data = model_selection.train_test_split(X, y, test_size=0.33, random_state=0)

L = onp.linalg.norm(X, ord=2) ** 2


def optimality_fun(params, lam, data):
  # Proximal-gradient fixed point: zero exactly at the Lasso solution.
  X, y = data
  n_samples = X.shape[0]
  return prox.prox_lasso(
      params - jax.grad(objective.least_squares)(params, (X, y)) * n_samples / L,
      lam * len(y) / L) - params


def make_restricted_optimality_fun(support):
  def restricted_optimality_fun(restricted_params, lam, data):
    # Suboptimal: ideally restricted_X would be computed once and for all.
    X, y = data
    restricted_X = X[:, support]
    return optimality_fun(restricted_params, lam, (restricted_X, y))
  return restricted_optimality_fun


@implicit_diff.sparse_custom_root(
    optimality_fun=optimality_fun,
    make_restricted_optimality_fun=make_restricted_optimality_fun)
def lasso_solver(init_params, lam, data):
  """Solve the Lasso on the training split."""
  X_tr, y_tr = data
  # TODO add warm start?
  sol = test_util.lasso_skl(X_tr, y_tr, lam)
  return sol

# Non-sparse variant, for comparison:
# @implicit_diff.custom_root(
#     optimality_fun=optimality_fun)
# def lasso_solver(init_params, lam, data):
#   """Solve the Lasso on the training split."""
#   X_tr, y_tr = data
#   # TODO add warm start?
#   sol = test_util.lasso_skl(X_tr, y_tr, lam)
#   return sol


# Perhaps confusingly, theta is a parameter of the outer objective,
# but lam = jnp.exp(theta) is a hyperparameter of the inner objective.
def outer_objective(theta, init_inner, data):
  """Validation loss."""
  X_tr, X_val, y_tr, y_val = data
  # We use the bijective mapping lam = jnp.exp(theta)
  # both to optimize in log-space and to ensure positivity.
  lam = jnp.exp(theta)
  w_fit = lasso_solver(init_inner, lam, (X_tr, y_tr))
  y_pred = jnp.dot(X_val, w_fit)
  loss_value = jnp.mean((y_pred - y_val) ** 2)
  # We return w_fit as auxiliary data.
  # Auxiliary data is stored in the optimizer state (see below).
  return loss_value, w_fit


# Initialize solver.
solver = OptaxSolver(opt=optax.adam(1e-2), fun=outer_objective, has_aux=True)
lam_max = jnp.max(jnp.abs(X.T @ y)) / len(y)
lam = lam_max / 10
theta_init = jnp.log(lam)
theta, state = solver.init(theta_init)
init_w = jnp.zeros(X.shape[1])

t_start = time.time()
# Run outer loop.
for _ in range(10):
  theta, state = solver.update(
      params=theta, state=state, init_inner=init_w, data=data)
  # The auxiliary data returned by the outer loss is stored in the state.
  init_w = state.aux
  print(f"[Step {state.iter_num}] Validation loss: {state.value:.3f}.")
t_elapsed = time.time() - t_start
print("Time taken for 10 iterations: %.2f s" % t_elapsed)

# if __name__ == "__main__":
#   app.run(main)