expanded verbose for all algorithms, made verbose a boolean for all algorithms
vroulet committed Nov 22, 2023
1 parent 48802ae commit 9e173ac
Showing 28 changed files with 309 additions and 127 deletions.
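Since only 10 of the 28 changed files are shown below, here is a hedged sketch of the user-facing effect: `verbose` now accepts a boolean (or an integer for solvers that forward verbosity to an inner linesearch). The objective and the choice of `jaxopt.GradientDescent` are illustrative assumptions consistent with the commit title; only the `verbose` argument itself comes from this diff.

```python
# Minimal sketch: enabling per-iteration logging on a jaxopt solver.
import jax.numpy as jnp
import jaxopt

def fun(w):
  return jnp.sum((w - 1.0) ** 2)

# verbose=True (or 1) prints one "INFO: jaxopt.<Solver>: ..." line per step.
solver = jaxopt.GradientDescent(fun=fun, maxiter=10, verbose=True)
params, state = solver.run(jnp.zeros(3))
```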
7 changes: 5 additions & 2 deletions jaxopt/_src/anderson.py
@@ -18,6 +18,7 @@
from typing import Callable
from typing import NamedTuple
from typing import List
+ from typing import Union

from typing import Optional
from dataclasses import dataclass
@@ -134,7 +135,7 @@ class AndersonAcceleration(base.IterativeSolver):
has_aux: whether fixed_point_fun returns additional data. (default: False)
This additional data is not taken into account by the fixed point.
The solver returns the `aux` associated with the last iterate (i.e. the fixed point).
- verbose: whether to print error on every iteration or not.
+ verbose: whether to print information on every iteration or not.
implicit_diff: whether to enable implicit diff or autodiff of unrolled
iterations.
implicit_diff_solve: the linear system solver to use.
@@ -154,7 +155,7 @@ class AndersonAcceleration(base.IterativeSolver):
tol: float = 1e-5
ridge: float = 1e-5
has_aux: bool = False
- verbose: bool = False
+ verbose: Union[bool, int] = False
implicit_diff: bool = True
implicit_diff_solve: Optional[Callable] = None
jit: bool = True
@@ -246,6 +247,8 @@ def use_param(t):
aux=aux,
num_fun_eval=state.num_fun_eval+1)

+ if self.verbose:
+   self.log_info(next_state, error_name="Residual Norm")
return base.OptStep(params=next_params, state=next_state)

def optimality_fun(self, params, *args, **kwargs):
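A usage sketch for the "Residual Norm" logging added above; the fixed-point map (a Babylonian square-root iteration) is an assumption for illustration, while the `AndersonAcceleration` API is as documented by jaxopt.

```python
import jax.numpy as jnp
import jaxopt

def T(x):
  # Babylonian iteration: fixed point at sqrt(2).
  return 0.5 * (x + 2.0 / x)

aa = jaxopt.AndersonAcceleration(fixed_point_fun=T, history_size=5,
                                 maxiter=50, verbose=True)
sol = aa.run(jnp.asarray(1.0))
print(sol.params)  # ~1.41421, with one residual-norm line logged per step
```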
7 changes: 5 additions & 2 deletions jaxopt/_src/anderson_wrapper.py
@@ -68,7 +68,7 @@ class AndersonWrapper(base.IterativeSolver):
beta: momentum in Anderson updates. (default: 1).
ridge: ridge regularization in solver.
Consider increasing this value if the solver returns ``NaN``.
- verbose: whether to print error on every iteration or not.
+ verbose: whether to print information on every iteration or not.
implicit_diff: whether to enable implicit diff or autodiff of unrolled
iterations.
implicit_diff_solve: the linear system solver to use.
@@ -80,7 +80,7 @@ class AndersonWrapper(base.IterativeSolver):
mixing_frequency: int = None
beta: float = 1.
ridge: float = 1e-5
- verbose: bool = False
+ verbose: Union[bool, int] = False
implicit_diff: bool = True
implicit_diff_solve: Optional[Callable] = None
jit: bool = True
@@ -161,6 +161,9 @@ def use_param(t):
params_history=params_history,
residuals_history=residuals_history,
residual_gram=residual_gram)

+ if self.verbose:
+   self.log_info(next_state, error_name="Inner Solver Error")
return base.OptStep(params=next_params, state=next_state)

def optimality_fun(self, params, *args, **kwargs):
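A sketch of the wrapper's new "Inner Solver Error" logging; the wrapped solver and objective are assumptions, and `solver` / `history_size` follow jaxopt's documented `AndersonWrapper` fields.

```python
import jax.numpy as jnp
import jaxopt

def fun(w):
  return jnp.sum((w - 1.0) ** 2)

# Anderson acceleration of a plain (non-accelerated) gradient descent.
inner = jaxopt.GradientDescent(fun=fun, stepsize=0.1, acceleration=False)
aw = jaxopt.AndersonWrapper(solver=inner, history_size=5, verbose=True)
sol = aw.run(jnp.zeros(3))
```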
14 changes: 12 additions & 2 deletions jaxopt/_src/armijo_sgd.py
@@ -21,6 +21,7 @@
from typing import Callable
from typing import NamedTuple
from typing import Optional
+ from typing import Union

import jax
import jax.lax as lax
@@ -191,7 +192,7 @@ class ArmijoSGD(base.StochasticSolver):
maxiter: maximum number of solver iterations.
maxls: maximum number of steps in line search.
tol: tolerance to use.
- verbose: whether to print error on every iteration or not.
+ verbose: whether to print information on every iteration or not.
implicit_diff: whether to enable implicit diff or autodiff of unrolled
iterations.
@@ -224,7 +225,7 @@ class ArmijoSGD(base.StochasticSolver):
maxiter: int = 500
maxls: int = 15
tol: float = 1e-3
- verbose: int = 0
+ verbose: Union[bool, int] = False

implicit_diff: bool = False
implicit_diff_solve: Optional[Callable] = None
@@ -315,6 +316,15 @@ def update(self, params, state, *args, **kwargs) -> base.OptStep:
stepsize=jnp.asarray(stepsize, dtype=dtype),
velocity=next_velocity)

+ if self.verbose:
+   self.log_info(
+       next_state,
+       error_name="Gradient Norm",
+       additional_info={
+           'Objective Value': next_state.value,
+           'Stepsize': stepsize
+       },
+   )
return base.OptStep(next_params, next_state)

def optimality_fun(self, params, *args, **kwargs):
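A sketch of the expanded logging on the stochastic solver: gradient norm (the stopping criterion), objective value, and stepsize are printed each iteration. The synthetic data and loss are assumptions.

```python
import jax.numpy as jnp
import jaxopt

X = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, 1.0]])
y = jnp.array([1.0, 2.0, 3.0, 4.0])

def loss(w, data):
  inputs, targets = data
  return jnp.mean((inputs @ w - targets) ** 2)

sgd = jaxopt.ArmijoSGD(fun=loss, maxiter=20, verbose=True)
params, state = sgd.run(jnp.zeros(2), data=(X, y))
```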
16 changes: 14 additions & 2 deletions jaxopt/_src/backtracking_linesearch.py
@@ -77,7 +77,7 @@ class BacktrackingLineSearch(base.IterativeLineSearch):
maxiter: maximum number of line search iterations.
tol: tolerance of the stopping criterion.
- verbose: whether to print error on every iteration or not.
+ verbose: whether to print information on every iteration or not.
jit: whether to JIT-compile the optimization loop (default: "auto").
unroll: whether to unroll the optimization loop (default: "auto").
@@ -95,7 +95,7 @@ class BacktrackingLineSearch(base.IterativeLineSearch):
decrease_factor: float = 0.8
max_stepsize: float = 1.0

- verbose: int = 0
+ verbose: Union[bool, int] = False
jit: base.AutoOrBoolean = "auto"
unroll: base.AutoOrBoolean = "auto"

@@ -283,6 +283,18 @@ def update(
num_fun_eval=num_fun_eval,
num_grad_eval=num_grad_eval)

+ if self.verbose:
+   additional_info = {'Stepsize': stepsize, 'Objective Value': new_value}
+   if self.condition != 'armijo':
+     error_name = "Minimum Decrease & Curvature Errors"
+     additional_info.update({'Decrease Error': error_cond1})
+   else:
+     error_name = "Decrease Error"
+   self.log_info(
+       new_state,
+       error_name=error_name,
+       additional_info=additional_info
+   )
return base.LineSearchStep(stepsize=new_stepsize, state=new_state)

def _compute_final_grad(self, params, fun_args, fun_kwargs):
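As the hunk above shows, the logged error name now depends on the linesearch condition: Wolfe-type conditions report both decrease and curvature errors, while `'armijo'` reports only the decrease error. A hedged usage sketch (objective assumed):

```python
import jax.numpy as jnp
import jaxopt

def fun(w):
  return jnp.sum((w - 1.0) ** 2)

ls = jaxopt.BacktrackingLineSearch(fun=fun, condition="strong-wolfe",
                                   maxiter=20, verbose=True)
stepsize, state = ls.run(init_stepsize=1.0, params=jnp.zeros(3))
```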
18 changes: 15 additions & 3 deletions jaxopt/_src/base.py
@@ -264,10 +264,22 @@ def _get_unroll_option(self):

def _cond_fun(self, inputs):
_, state = inputs[0]
- if self.verbose:
-   name = self.__class__.__name__
-   jax.debug.print("Solver: %s, Error: {error}" % name, error=state.error)
return state.error > self.tol

+ def log_info(self, state, error_name='Error', additional_info={}):
+   """Base info at the end of the update."""
+   other_info_kw = ' '.join([key + ":{} " for key in additional_info.keys()])
+   name = self.__class__.__name__
+   jax.debug.print(
+       "INFO: jaxopt." + name + ": " + \
+       "Iter: {} " + \
+       error_name + " (stop. crit.): {} " + \
+       other_info_kw,
+       state.iter_num,
+       state.error,
+       *additional_info.values(),
+       ordered=True
+   )

def _body_fun(self, inputs):
(params, state), (args, kwargs) = inputs
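A standalone sketch of the message `log_info` assembles; the dummy state and solver name are assumptions, while the format-string construction mirrors the method added above.

```python
from dataclasses import dataclass

import jax

@dataclass
class DummyState:  # assumed stand-in for a solver state
  iter_num: int = 3
  error: float = 0.12

def log_info(state, name="BFGS", error_name="Error", additional_info={}):
  other_info_kw = ' '.join([key + ":{} " for key in additional_info.keys()])
  jax.debug.print(
      "INFO: jaxopt." + name + ": Iter: {} " +
      error_name + " (stop. crit.): {} " + other_info_kw,
      state.iter_num, state.error, *additional_info.values(), ordered=True)

log_info(DummyState(), error_name="Gradient Norm",
         additional_info={'Stepsize': 0.5})
# Prints: INFO: jaxopt.BFGS: Iter: 3 Gradient Norm (stop. crit.): 0.12 Stepsize:0.5
```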
18 changes: 15 additions & 3 deletions jaxopt/_src/bfgs.py
@@ -107,7 +107,8 @@ class BFGS(base.IterativeSolver):
implicit_diff_solve: the linear system solver to use.
jit: whether to JIT-compile the optimization loop (default: True).
unroll: whether to unroll the optimization loop (default: "auto").
- verbose: whether to print error on every iteration or not.
+ verbose: if set to True or 1, prints information at each step of
+   the solver; if set to 2, also prints information from the linesearch.
Reference:
Jorge Nocedal and Stephen Wright.
@@ -141,7 +142,7 @@ class BFGS(base.IterativeSolver):
jit: bool = True
unroll: base.AutoOrBoolean = "auto"

- verbose: bool = False
+ verbose: Union[bool, int] = False

def init_state(self,
init_params: Any,
@@ -260,6 +261,17 @@ def update(self,
num_fun_eval=new_num_fun_eval,
num_linesearch_iter=new_num_linesearch_iter)

+ if self.verbose:
+   self.log_info(
+       new_state,
+       error_name="Gradient Norm",
+       additional_info={
+           "Objective Value": new_value,
+           "Stepsize": new_stepsize,
+           "Number Linesearch Iterations":
+               new_state.num_linesearch_iter - state.num_linesearch_iter
+       }
+   )
return base.OptStep(params=new_params, state=new_state)

def optimality_fun(self, params, *args, **kwargs):
@@ -289,7 +301,7 @@ def __post_init__(self):
max_stepsize=self.max_stepsize,
jit=self.jit,
unroll=unroll,
- verbose=self.verbose,
+ verbose=int(self.verbose)-1
)
self.run_ls = self.linesearch_solver.run

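A sketch of the two-level verbosity described in the new docstring: `verbose=1` logs BFGS iterations, and `verbose=2` also logs the inner linesearch, since the `BacktrackingLineSearch` constructed in `__post_init__` receives `int(self.verbose) - 1`. Note that `int(True) - 1 == 0`, so plain `verbose=True` keeps the linesearch quiet. The objective and the `linesearch="backtracking"` choice are assumptions for illustration.

```python
import jax.numpy as jnp
import jaxopt

def rosenbrock(x):
  return jnp.sum(100.0 * (x[1:] - x[:-1] ** 2) ** 2 + (1.0 - x[:-1]) ** 2)

bfgs = jaxopt.BFGS(fun=rosenbrock, linesearch="backtracking",
                   maxiter=50, verbose=2)
res = bfgs.run(jnp.zeros(2))
```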
14 changes: 12 additions & 2 deletions jaxopt/_src/bisection.py
@@ -18,6 +18,7 @@
from typing import Callable
from typing import NamedTuple
from typing import Optional
+ from typing import Union

from dataclasses import dataclass

@@ -56,7 +57,7 @@ class Bisection(base.IterativeSolver):
check_bracket: whether to check correctness of the bracketing interval.
If True, the method ``run`` cannot be jitted.
implicit_diff_solve: the linear system solver to use.
- verbose: whether to print error on every iteration or not.
+ verbose: whether to print information on every iteration or not.
jit: whether to JIT-compile the bisection loop (default: True).
unroll: whether to unroll the bisection loop (default: "auto").
@@ -67,7 +68,7 @@
maxiter: int = 30
tol: float = 1e-5
check_bracket: bool = True
- verbose: bool = False
+ verbose: Union[bool, int] = False
implicit_diff_solve: Optional[Callable] = None
has_aux: bool = False
jit: bool = True
@@ -151,6 +152,15 @@ def update(self,
aux=aux,
num_fun_eval=state.num_fun_eval + 1)

+ if self.verbose:
+   self.log_info(
+       state,
+       error_name="Absolute Value Output",
+       additional_info={
+           "High Point": high,
+           "Low Point": low
+       }
+   )
return base.OptStep(params=params, state=state)

def run(self,
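A sketch of the bisection logging, which now reports the absolute output value along with the current bracket; the root-finding problem is an assumption.

```python
import jaxopt

def F(x):
  return x ** 2 - 2.0  # sign change on [0, 2], root at sqrt(2)

bisec = jaxopt.Bisection(optimality_fun=F, lower=0.0, upper=2.0, verbose=True)
print(bisec.run().params)  # ~1.41421
```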
9 changes: 7 additions & 2 deletions jaxopt/_src/block_cd.py
@@ -64,7 +64,7 @@ class BlockCoordinateDescent(base.IterativeSolver):
maxiter: maximum number of proximal gradient descent iterations.
tol: tolerance to use.
- verbose: whether to print error on every iteration or not.
+ verbose: whether to print information on every iteration or not.
implicit_diff: whether to enable implicit diff or autodiff of unrolled
iterations.
@@ -77,7 +77,7 @@
block_prox: Callable
maxiter: int = 500
tol: float = 1e-4
- verbose: int = 0
+ verbose: Union[bool, int] = False
implicit_diff: bool = True
implicit_diff_solve: Optional[Callable] = None
jit: bool = True
@@ -167,6 +167,11 @@ def body_fun(i, tup):
num_grad_eval=state.num_grad_eval + n_for,
num_prox_eval=state.num_prox_eval + n_for)

+ if self.verbose:
+   self.log_info(
+       state,
+       error_name="Distance btw Iterates"
+   )
return base.OptStep(params=params, state=state)

def _fixed_point_fun(self, params, hyperparams_prox, *args, **kwargs):
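A sketch of block coordinate descent with the new "Distance btw Iterates" logging, on a lasso-type problem. The data and regularization strength are assumptions; `objective.least_squares` and `prox.prox_lasso` are existing jaxopt helpers.

```python
import jax.numpy as jnp
import jaxopt
from jaxopt import objective, prox

X = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
y = jnp.array([1.0, 2.0, 3.0])

bcd = jaxopt.BlockCoordinateDescent(fun=objective.least_squares,
                                    block_prox=prox.prox_lasso,
                                    maxiter=30, verbose=True)
sol = bcd.run(jnp.zeros(2), hyperparams_prox=0.1, data=(X, y))
```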
19 changes: 14 additions & 5 deletions jaxopt/_src/broyden.py
@@ -177,8 +177,8 @@ class Broyden(base.IterativeSolver):
jit: whether to JIT-compile the optimization loop (default: True).
unroll: whether to unroll the optimization loop (default: "auto").
- verbose: whether to print error on every iteration or not.
-   Warning: verbose=True will automatically disable jit.
+ verbose: if set to True or 1, prints information at each step of
+   the solver; if set to 2, also prints information from the linesearch.
Reference:
Charles G. Broyden.
@@ -212,12 +212,10 @@ class Broyden(base.IterativeSolver):
jit: bool = True
unroll: base.AutoOrBoolean = "auto"

- verbose: bool = False
+ verbose: Union[bool, int] = False

def _cond_fun(self, inputs):
_, state = inputs[0]
- if self.verbose:
-   jax.debug.print("Solver: Broyden, Error: {error}", error=state.error)
# We continue the optimization loop while the error tolerance is not met and,
# either failed linesearch is disallowed or linesearch hasn't failed.
return (state.error > self.tol) & jnp.logical_or(not self.stop_if_linesearch_fails, ~state.failed_linesearch)
@@ -330,6 +328,7 @@ def ls_fun_with_aux(params, *args, **kwargs):
jit=self.jit,
unroll=self.unroll,
has_aux=True,
+ verbose=int(self.verbose)-1,
tol=1e-2)
init_stepsize = jnp.where(state.stepsize <= self.min_stepsize,
# If stepsize became too small, we restart it.
@@ -382,6 +381,16 @@ def ls_fun_with_aux(params, *args, **kwargs):
num_linesearch_iter=new_num_linesearch_iter,
failed_linesearch=failed_linesearch)

+ if self.verbose:
+   self.log_info(
+       new_state,
+       error_name="Norm Output",
+       additional_info={
+           "Stepsize": new_stepsize,
+           "Number Linesearch Iterations":
+               new_state.num_linesearch_iter - state.num_linesearch_iter
+       }
+   )
return base.OptStep(params=new_params, state=new_state)

def optimality_fun(self, params, *args, **kwargs):
7 changes: 5 additions & 2 deletions jaxopt/_src/cd_qp.py
@@ -17,6 +17,7 @@
from typing import Callable
from typing import NamedTuple
from typing import Optional
+ from typing import Union

from dataclasses import dataclass

@@ -62,7 +63,7 @@ class BoxCDQP(base.IterativeSolver):
Attributes:
maxiter: maximum number of coordinate descent iterations.
tol: tolerance to use.
- verbose: whether to print error on every iteration or not.
+ verbose: whether to print information on every iteration or not.
implicit_diff: whether to enable implicit diff or autodiff of unrolled
iterations.
@@ -73,7 +74,7 @@
"""
maxiter: int = 500
tol: float = 1e-4
- verbose: int = 0
+ verbose: Union[bool, int] = False
implicit_diff: bool = True
implicit_diff_solve: Optional[Callable] = None
jit: bool = True
@@ -124,6 +125,8 @@ def update(self,

state = BoxCDQPState(iter_num=state.iter_num + 1, error=error)

+ if self.verbose:
+   self.log_info(state)
return base.OptStep(params=params, state=state)

def _fixed_point_fun(self,
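A sketch of `BoxCDQP` with logging enabled; since `update` calls `log_info(state)` with no overrides, the default error name "Error" is printed. The QP data are assumptions.

```python
import jax.numpy as jnp
import jaxopt

# minimize 0.5 * x^T Q x + c^T x  subject to  lower <= x <= upper
Q = 2.0 * jnp.eye(2)
c = jnp.array([-2.0, -4.0])
lower, upper = jnp.zeros(2), jnp.ones(2)

qp = jaxopt.BoxCDQP(maxiter=100, verbose=True)
sol = qp.run(jnp.zeros(2), params_obj=(Q, c), params_ineq=(lower, upper))
print(sol.params)  # [1., 1.]
```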
(The remaining 18 changed files are not shown.)
