
Commit

Merge pull request #142 from ita9naiwa:fr
PiperOrigin-RevId: 424585493
JAXopt authors committed Jan 27, 2022
2 parents fa9b14d + 8ca15c2 commit ea151c1
Showing 8 changed files with 310 additions and 2 deletions.
1 change: 1 addition & 0 deletions docs/api.rst
@@ -13,6 +13,7 @@ Unconstrained
jaxopt.GradientDescent
jaxopt.LBFGS
jaxopt.ScipyMinimize
jaxopt.NonlinearCG

Constrained
~~~~~~~~~~~
3 changes: 2 additions & 1 deletion docs/unconstrained.rst
@@ -39,6 +39,7 @@ Solvers
jaxopt.GradientDescent
jaxopt.LBFGS
jaxopt.ScipyMinimize
jaxopt.NonlinearCG

Instantiating and running the solver
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -52,7 +53,7 @@ instantiated and run as follows::
# Alternatively, we could have used one of these solvers as well:
# solver = jaxopt.GradientDescent(fun=ridge_reg_objective, maxiter=500)
# solver = jaxopt.ScipyMinimize(fun=ridge_reg_objective, method="L-BFGS-B", maxiter=500)
# solver = jaxopt.NonlinearCG(fun=ridge_reg_objective, maxiter=300, method="polak-ribiere")

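A minimal sketch of the same instantiate-and-run pattern with the new solver (the toy
objective below is a stand-in for illustration, not the ridge regression objective used
elsewhere on this page)::

    import jax.numpy as jnp
    import jaxopt

    def toy_objective(w):  # hypothetical stand-in objective, not part of this commit
      return jnp.sum((w - 1.0) ** 2)

    solver = jaxopt.NonlinearCG(fun=toy_objective, maxiter=300, method="polak-ribiere")
    params, state = solver.run(jnp.zeros(3))
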
Unpacking results
~~~~~~~~~~~~~~~~~

1 change: 1 addition & 0 deletions jaxopt/__init__.py
@@ -32,6 +32,7 @@
from jaxopt._src.iterative_refinement import IterativeRefinement
from jaxopt._src.lbfgs import LBFGS
from jaxopt._src.mirror_descent import MirrorDescent
from jaxopt._src.nonlinear_cg import NonlinearCG
from jaxopt._src.optax_wrapper import OptaxSolver
from jaxopt._src.osqp import BoxOSQP
from jaxopt._src.osqp import OSQP
204 changes: 204 additions & 0 deletions jaxopt/_src/nonlinear_cg.py
@@ -0,0 +1,204 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Nonlinear conjugate gradient algorithm"""

from typing import Any
from typing import Callable
from typing import NamedTuple
from typing import Optional

from dataclasses import dataclass

import jax
import jax.numpy as jnp

from jaxopt._src import base
from jaxopt._src.backtracking_linesearch import BacktrackingLineSearch
from jaxopt.tree_util import tree_vdot
from jaxopt.tree_util import tree_scalar_mul
from jaxopt.tree_util import tree_add_scalar_mul
from jaxopt.tree_util import tree_sub
from jaxopt.tree_util import tree_div
from jaxopt.tree_util import tree_l2_norm


class NonlinearCGState(NamedTuple):
"""Named tuple containing state information."""
iter_num: int
stepsize: float
error: float
value: float
grad: Any
descent_direction: Any
aux: Optional[Any] = None


@dataclass(eq=False)
class NonlinearCG(base.IterativeSolver):
"""
Nonlinear Conjugate Gradient Solver.
Attributes:
fun: a smooth function of the form ``fun(x, *args, **kwargs)``.
method: which variant to calculate the beta parameter in Nonlinear CG.
"polak-ribiere", "fletcher-reeves", "hestenes-stiefel"
(default: "polak-ribiere")
has_aux: whether function fun outputs one (False) or more values (True).
When True it will be assumed by default that fun(...)[0] is the objective.
maxiter: maximum number of proximal gradient descent iterations.
tol: tolerance of the stopping criterion.
maxls: maximum number of iterations to use in the line search.
decrease_factor: factor by which to decrease the stepsize during line search
(default: 0.8).
increase_factor: factor by which to increase the stepsize during line search
(default: 1.2).
implicit_diff: whether to enable implicit diff or autodiff of unrolled
iterations.
implicit_diff_solve: the linear system solver to use.
jit: whether to JIT-compile the optimization loop (default: "auto").
unroll: whether to unroll the optimization loop (default: "auto").
verbose: whether to print error on every iteration or not.
Warning: verbose=True will automatically disable jit.
Reference:
Jorge Nocedal and Stephen Wright.
Numerical Optimization, second edition.
Algorithm 5.4 (page 121).
"""

fun: Callable
has_aux: bool = False

maxiter: int = 100
tol: float = 1e-3

method: str = "polak-ribiere" # same as SciPy
condition: str = "strong-wolfe"
maxls: int = 15
decrease_factor: float = 0.8
increase_factor: float = 1.2
implicit_diff: bool = True
implicit_diff_solve: Optional[Callable] = None

jit: base.AutoOrBoolean = "auto"
unroll: base.AutoOrBoolean = "auto"

verbose: int = 0

def init_state(self,
init_params: Any,
*args,
**kwargs) -> NonlinearCGState:
"""Initialize the solver state.
Args:
init_params: pytree containing the initial parameters.
*args: additional positional arguments to be passed to ``fun``.
**kwargs: additional keyword arguments to be passed to ``fun``.
Returns:
state
"""
value, grad = self._value_and_grad_fun(init_params, *args, **kwargs)

return NonlinearCGState(iter_num=jnp.asarray(0),
stepsize=jnp.asarray(1.0),
error=jnp.asarray(jnp.inf),
value=value,
grad=grad,
descent_direction=tree_scalar_mul(-1.0, grad))

def update(self,
params: Any,
state: NonlinearCGState,
*args,
**kwargs) -> base.OptStep:
"""Performs one iteration of Fletcher-Reeves Algorithm.
Args:
params: pytree containing the parameters.
state: named tuple containing the solver state.
*args: additional positional arguments to be passed to ``fun``.
**kwargs: additional keyword arguments to be passed to ``fun``.
Returns:
(params, state)
"""

eps = 1e-6
value, grad, descent_direction = state.value, state.grad, state.descent_direction
init_stepsize = state.stepsize * self.increase_factor
ls = BacktrackingLineSearch(fun=self._value_and_grad_fun,
value_and_grad=True,
maxiter=self.maxls,
decrease_factor=self.decrease_factor,
condition=self.condition)
new_stepsize, ls_state = ls.run(init_stepsize=init_stepsize,
params=params,
value=value,
grad=grad,
descent_direction=descent_direction,
*args, **kwargs)

new_params = tree_add_scalar_mul(params, new_stepsize, descent_direction)
(new_value, new_aux), new_grad = self._value_and_grad_with_aux(new_params, *args, **kwargs)

if self.method == "polak-ribiere":
# See Numerical Optimization, second edition, equation (5.44).
gTg = tree_vdot(grad, grad)
gTg = jnp.where(gTg >= eps, gTg, eps)
new_beta = tree_div(tree_vdot(new_grad, tree_sub(new_grad, grad)), gTg)
new_beta = jax.nn.relu(new_beta)
elif self.method == "fletcher-reeves":
# See Numerical Optimization, second edition, equation (5.41a).
gTg = tree_vdot(grad, grad)
gTg = jnp.where(gTg >= eps, gTg, eps)
new_beta = tree_div(tree_vdot(new_grad, new_grad), gTg)
elif self.method == 'hestenes-stiefel':
# See Numerical Optimization, second edition, equation (5.45).
grad_diff = tree_sub(new_grad, grad)
dTg = tree_vdot(descent_direction, grad_diff)
dTg = jnp.where(dTg >= eps, dTg, eps)
new_beta = tree_div(tree_vdot(new_grad, grad_diff), dTg)
else:
raise ValueError("method should be either 'polak-ribiere', 'fletcher-reeves', or 'hestenes-stiefel'")
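# For reference, the three branches above compute beta as follows (restating
# the cited equations from Nocedal & Wright, with g_k the current gradient,
# g_{k+1} the new gradient and d_k the current descent direction):
#   fletcher-reeves  (5.41a): beta = <g_{k+1}, g_{k+1}> / <g_k, g_k>
#   polak-ribiere    (5.44):  beta = max(0, <g_{k+1}, g_{k+1} - g_k> / <g_k, g_k>)
#   hestenes-stiefel (5.45):  beta = <g_{k+1}, g_{k+1} - g_k> / <d_k, g_{k+1} - g_k>
# The eps clamp above keeps each denominator at least eps; the line below then
# forms the new direction d_{k+1} = -g_{k+1} + beta * d_k.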

new_descent_direction = tree_add_scalar_mul(tree_scalar_mul(-1, new_grad), new_beta, descent_direction)
new_state = NonlinearCGState(iter_num=state.iter_num + 1,
stepsize=jnp.asarray(new_stepsize),
error=tree_l2_norm(new_grad),
value=new_value,
grad=new_grad,
descent_direction=new_descent_direction,
aux=new_aux)

return base.OptStep(params=new_params, state=new_state)

def optimality_fun(self, params, *args, **kwargs):
"""Optimality function mapping compatible with ``@custom_root``."""
return self._grad_fun(params, *args, **kwargs)

def _value_and_grad_fun(self, params, *args, **kwargs):
(value, aux), grad = self._value_and_grad_with_aux(params, *args, **kwargs)
return value, grad

def _grad_fun(self, params, *args, **kwargs):
return self._value_and_grad_fun(params, *args, **kwargs)[1]

def __post_init__(self):
if self.has_aux:
self._fun = lambda *a, **kw: self.fun(*a, **kw)[0]
fun_with_aux = self.fun
else:
self._fun = self.fun
fun_with_aux = lambda *a, **kw: (self.fun(*a, **kw), None)

self._value_and_grad_with_aux = jax.value_and_grad(fun_with_aux,
has_aux=True)

self.reference_signature = self.fun
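
A minimal usage sketch of the new solver (the objective, data and parameter pytree below
are made up for illustration; only the NonlinearCG constructor arguments and the
run/state interface come from the file above):

import jax.numpy as jnp
from jaxopt import NonlinearCG

def loss(params, data):
  # Simple least-squares objective over a dict-structured pytree of parameters.
  X, y = data
  residuals = jnp.dot(X, params["w"]) + params["b"] - y
  return jnp.mean(residuals ** 2)

X = jnp.arange(24.0).reshape(8, 3) / 10.0
y = jnp.arange(8.0)
init_params = {"w": jnp.zeros(3), "b": 0.0}

solver = NonlinearCG(fun=loss, method="fletcher-reeves", maxiter=100, tol=1e-4)
params, state = solver.run(init_params, data=(X, y))
print(state.iter_num, state.error)  # iterations run and gradient-norm error at the last iterate
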
2 changes: 1 addition & 1 deletion jaxopt/_src/tree_util.py
@@ -60,7 +60,7 @@ def tree_vdot(tree_x, tree_y):

def tree_dot(tree_x, tree_y):
"""Compute leaves-wise dot product between pytree of arrays.
Useful to store block diagonal linear operators: each leaf of the tree
corresponds to a block."""
return tree_map(jnp.dot, tree_x, tree_y)
1 change: 1 addition & 0 deletions jaxopt/tree_util.py
@@ -21,6 +21,7 @@
from jaxopt._src.tree_util import tree_scalar_mul
from jaxopt._src.tree_util import tree_add_scalar_mul
from jaxopt._src.tree_util import tree_vdot
from jaxopt._src.tree_util import tree_div
from jaxopt._src.tree_util import tree_sum
from jaxopt._src.tree_util import tree_l2_norm
from jaxopt._src.tree_util import tree_zeros_like
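
For context, a small sketch of the leaf-wise semantics of these re-exported helpers
(the pytrees below are made up for illustration):

import jax.numpy as jnp
from jaxopt.tree_util import tree_div, tree_vdot

tree_a = (jnp.array([2.0, 4.0]), {"k": jnp.array([6.0])})
tree_b = (jnp.array([1.0, 2.0]), {"k": jnp.array([3.0])})

tree_div(tree_a, tree_b)   # leaf-wise division: (array([2., 2.]), {'k': array([2.])})
tree_vdot(tree_a, tree_b)  # sum of leaf-wise vdots: 2*1 + 4*2 + 6*3 = 28.0
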
91 changes: 91 additions & 0 deletions tests/nonlinear_cg_test.py
@@ -0,0 +1,91 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from absl.testing import absltest
from absl.testing import parameterized

from jax import test_util as jtu
import jax.random
import jax.numpy as jnp

import numpy as onp

from jaxopt import NonlinearCG
from jaxopt import objective
from jaxopt._src import test_util
from sklearn import datasets


def get_random_pytree():
key = jax.random.PRNGKey(1213)

def rn(key, l=3):
return 0.05 * jnp.array(onp.random.normal(size=(10,)))

def _get_random_pytree(curr_depth=0, max_depth=3):
r = onp.random.uniform()
if curr_depth == max_depth or r <= 0.2: # leaf
return rn(key)
elif curr_depth <= 1 or r <= 0.7: # list
return [
_get_random_pytree(curr_depth=curr_depth + 1, max_depth=max_depth)
for _ in range(2)
]
else: # dict
return {
str(_): _get_random_pytree(
curr_depth=curr_depth + 1, max_depth=max_depth
)
for _ in range(2)
}
return [rn(key), {'a': rn(key), 'b': rn(key)}, _get_random_pytree()]


class NonlinearCGTest(jtu.JaxTestCase):
def test_arbitrary_pytree(self):
def loss(w, data):
X, y = data
_w = jnp.concatenate(jax.tree_util.tree_leaves(w))
return ((jnp.dot(X, _w) - y) ** 2).mean()

w = get_random_pytree()
f_w = jnp.concatenate(jax.tree_util.tree_leaves(w))
X, y = datasets.make_classification(n_samples=15, n_features=f_w.shape[-1], n_classes=2, n_informative=3, random_state=0)
data = (X, y)
cg_model = NonlinearCG(fun=loss, tol=1e-2, maxiter=300, method="polak-ribiere")
w_fit, info = cg_model.run(w, data=data)
self.assertLessEqual(info.error, 5e-2)

@parameterized.product(method=["fletcher-reeves", "polak-ribiere", "hestenes-stiefel"])
def test_binary_logreg(self, method):
X, y = datasets.make_classification(n_samples=10, n_features=5, n_classes=2,
n_informative=3, random_state=0)
data = (X, y)
fun = objective.binary_logreg

w_init = jnp.zeros(X.shape[1])
cg_model = NonlinearCG(fun=fun, tol=1e-3, maxiter=100, method=method)
w_fit, info = cg_model.run(w_init, data=data)

# Check optimality conditions.
self.assertLessEqual(info.error, 5e-2)

# Compare against sklearn.
w_skl = test_util.logreg_skl(X, y, 1e-6, fit_intercept=False,
multiclass=False)
self.assertArraysAllClose(w_fit, w_skl, atol=5e-2)


if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())
9 changes: 9 additions & 0 deletions tests/tree_util_test.py
@@ -99,6 +99,15 @@ def test_tree_vdot(self):
got = tree_util.tree_vdot(self.tree_A, self.tree_B)
self.assertAllClose(expected, got)

def test_tree_div(self):
expected = (self.tree_A[0] / self.tree_B[0], self.tree_A[1] / self.tree_B[1])
got = tree_util.tree_div(self.tree_A, self.tree_B)
self.assertAllClose(expected, got)

got = tree_util.tree_div(self.tree_A_dict, self.tree_B_dict)
expected = (1.0, {'k1': 0.5, 'k2': (0.333333333, 0.25)}, 0.2)
self.assertAllClose(expected, got)

def test_tree_sum(self):
expected = jnp.sum(self.array_A)
got = tree_util.tree_sum(self.array_A)
