expanded verbose for all algorithms, made verbose a boolean for all algorithms #544

Merged: 1 commit, Jan 8, 2024
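For context, the net effect of this PR: every solver's `verbose` attribute becomes Union[bool, int], and any truthy value makes the solver log one line per iteration through the new log_info helper in base.py. A minimal usage sketch, assuming jaxopt's public BFGS API and the verbose semantics described in the updated docstrings below:

import jax.numpy as jnp
from jaxopt import BFGS

def fun(w):
  # Simple quadratic objective with minimum at w = 1.
  return jnp.sum((w - 1.0) ** 2)

# verbose=True (or 1): one log line per solver iteration.
solver = BFGS(fun=fun, maxiter=10, verbose=True)
params, state = solver.run(jnp.zeros(3))

# verbose=2: additionally logs every backtracking line-search iteration.
solver = BFGS(fun=fun, maxiter=10, verbose=2)
params, state = solver.run(jnp.zeros(3))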
7 changes: 5 additions & 2 deletions jaxopt/_src/anderson.py
@@ -18,6 +18,7 @@
from typing import Callable
from typing import NamedTuple
from typing import List
from typing import Union

from typing import Optional
from dataclasses import dataclass
@@ -134,7 +135,7 @@ class AndersonAcceleration(base.IterativeSolver):
has_aux: whether fixed_point_fun returns additional data. (default: False)
This additional data is not taken into account by the fixed point.
The solver returns the `aux` associated with the last iterate (i.e. the fixed point).
verbose: whether to print error on every iteration or not.
verbose: whether to print information on every iteration or not.
implicit_diff: whether to enable implicit diff or autodiff of unrolled
iterations.
implicit_diff_solve: the linear system solver to use.
@@ -154,7 +155,7 @@ class AndersonAcceleration(base.IterativeSolver):
tol: float = 1e-5
ridge: float = 1e-5
has_aux: bool = False
verbose: bool = False
verbose: Union[bool, int] = False
implicit_diff: bool = True
implicit_diff_solve: Optional[Callable] = None
jit: bool = True
@@ -246,6 +247,8 @@ def use_param(t):
aux=aux,
num_fun_eval=state.num_fun_eval+1)

if self.verbose:
self.log_info(next_state, error_name="Residual Norm")
return base.OptStep(params=next_params, state=next_state)

def optimality_fun(self, params, *args, **kwargs):
7 changes: 5 additions & 2 deletions jaxopt/_src/anderson_wrapper.py
@@ -68,7 +68,7 @@ class AndersonWrapper(base.IterativeSolver):
beta: momentum in Anderson updates. (default: 1).
ridge: ridge regularization in solver.
Consider increasing this value if the solver returns ``NaN``.
verbose: whether to print error on every iteration or not.
verbose: whether to print information on every iteration or not.
implicit_diff: whether to enable implicit diff or autodiff of unrolled
iterations.
implicit_diff_solve: the linear system solver to use.
@@ -80,7 +80,7 @@ class AndersonWrapper(base.IterativeSolver):
mixing_frequency: int = None
beta: float = 1.
ridge: float = 1e-5
verbose: bool = False
verbose: Union[bool, int] = False
implicit_diff: bool = True
implicit_diff_solve: Optional[Callable] = None
jit: bool = True
@@ -161,6 +161,9 @@ def use_param(t):
params_history=params_history,
residuals_history=residuals_history,
residual_gram=residual_gram)

if self.verbose:
self.log_info(next_state, error_name="Inner Solver Error")
return base.OptStep(params=next_params, state=next_state)

def optimality_fun(self, params, *args, **kwargs):
14 changes: 12 additions & 2 deletions jaxopt/_src/armijo_sgd.py
@@ -21,6 +21,7 @@
from typing import Callable
from typing import NamedTuple
from typing import Optional
from typing import Union

import jax
import jax.lax as lax
@@ -191,7 +192,7 @@ class ArmijoSGD(base.StochasticSolver):
maxiter: maximum number of solver iterations.
maxls: maximum number of steps in line search.
tol: tolerance to use.
verbose: whether to print error on every iteration or not.
verbose: whether to print information on every iteration or not.

implicit_diff: whether to enable implicit diff or autodiff of unrolled
iterations.
@@ -224,7 +225,7 @@ class ArmijoSGD(base.StochasticSolver):
maxiter: int = 500
maxls: int = 15
tol: float = 1e-3
verbose: int = 0
verbose: Union[bool, int] = False

implicit_diff: bool = False
implicit_diff_solve: Optional[Callable] = None
@@ -315,6 +316,15 @@ def update(self, params, state, *args, **kwargs) -> base.OptStep:
stepsize=jnp.asarray(stepsize, dtype=dtype),
velocity=next_velocity)

if self.verbose:
self.log_info(
next_state,
error_name="Gradient Norm",
additional_info={
'Objective Value': next_state.value,
'Stepsize': stepsize
},
)
return base.OptStep(next_params, next_state)

def optimality_fun(self, params, *args, **kwargs):
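Given the format string assembled by log_info (shown in base.py below), a line logged by ArmijoSGD with verbose=True would look roughly like this (spacing per the format string, values illustrative):

INFO: jaxopt.ArmijoSGD: Iter: 12 Gradient Norm (stop. crit.): 0.0431 Objective Value:0.2157  Stepsize:0.125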
16 changes: 14 additions & 2 deletions jaxopt/_src/backtracking_linesearch.py
@@ -77,7 +77,7 @@ class BacktrackingLineSearch(base.IterativeLineSearch):
maxiter: maximum number of line search iterations.
tol: tolerance of the stopping criterion.

verbose: whether to print error on every iteration or not.
verbose: whether to print information on every iteration or not.

jit: whether to JIT-compile the optimization loop (default: "auto").
unroll: whether to unroll the optimization loop (default: "auto").
@@ -95,7 +95,7 @@ class BacktrackingLineSearch(base.IterativeLineSearch):
decrease_factor: float = 0.8
max_stepsize: float = 1.0

verbose: int = 0
verbose: Union[bool, int] = False
jit: base.AutoOrBoolean = "auto"
unroll: base.AutoOrBoolean = "auto"

@@ -283,6 +283,18 @@ def update(
num_fun_eval=num_fun_eval,
num_grad_eval=num_grad_eval)

if self.verbose:
additional_info = {'Stepsize': stepsize, 'Objective Value': new_value}
if self.condition != 'armijo':
error_name = "Minimum Decrease & Curvature Errors"
additional_info.update({'Decrease Error': error_cond1})
else:
error_name = "Decrease Error"
self.log_info(
new_state,
error_name=error_name,
additional_info=additional_info
)
return base.LineSearchStep(stepsize=new_stepsize, state=new_state)

def _compute_final_grad(self, params, fun_args, fun_kwargs):
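Because only the Armijo condition is a pure sufficient-decrease test, the logged error name depends on condition. A hedged usage sketch (run arguments per jaxopt's line-search API; omitting value/grad/descent_direction assumes the line search derives them from fun, which may differ in detail):

import jax.numpy as jnp
from jaxopt import BacktrackingLineSearch

fun = lambda w: jnp.sum(w ** 2)

ls = BacktrackingLineSearch(fun=fun, condition="strong-wolfe", verbose=True)
stepsize, state = ls.run(init_stepsize=1.0, params=jnp.ones(2))
# Logs "Minimum Decrease & Curvature Errors" as the stopping criterion;
# with condition="armijo" it would log "Decrease Error" instead.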
18 changes: 15 additions & 3 deletions jaxopt/_src/base.py
@@ -264,10 +264,22 @@ def _get_unroll_option(self):

def _cond_fun(self, inputs):
_, state = inputs[0]
if self.verbose:
name = self.__class__.__name__
jax.debug.print("Solver: %s, Error: {error}" % name, error=state.error)
return state.error > self.tol

def log_info(self, state, error_name='Error', additional_info={}):
"""Base info at the end of the update."""
other_info_kw = ' '.join([key + ":{} " for key in additional_info.keys()])
name = self.__class__.__name__
jax.debug.print(
"INFO: jaxopt." + name + ": " + \
"Iter: {} " + \
error_name + " (stop. crit.): {} " + \
other_info_kw,
state.iter_num,
state.error,
*additional_info.values(),
ordered=True
)

def _body_fun(self, inputs):
(params, state), (args, kwargs) = inputs
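The key detail in log_info is jax.debug.print with ordered=True, which keeps log lines in program order even when the update step runs under jit; the format string is assembled eagerly in Python, and only the traced values flow through the {} placeholders. A self-contained sketch of the same pattern (standalone, not jaxopt code):

import jax
import jax.numpy as jnp

def log_info(name, iter_num, error, additional_info={}):
  # Build the format string first; the {} slots are filled with traced
  # values at runtime, so this is safe to call inside a jitted function.
  other_kw = ' '.join(key + ": {}" for key in additional_info)
  jax.debug.print(
      "INFO: " + name + ": Iter: {} Error: {} " + other_kw,
      iter_num, error, *additional_info.values(), ordered=True)

@jax.jit
def step(x):
  log_info("Demo", 1, jnp.abs(x), {"Stepsize": 0.5})
  return 0.9 * x

step(jnp.array(2.0))  # prints roughly: INFO: Demo: Iter: 1 Error: 2.0 Stepsize: 0.5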
18 changes: 15 additions & 3 deletions jaxopt/_src/bfgs.py
@@ -107,7 +107,8 @@ class BFGS(base.IterativeSolver):
implicit_diff_solve: the linear system solver to use.
jit: whether to JIT-compile the optimization loop (default: True).
unroll: whether to unroll the optimization loop (default: "auto").
verbose: whether to print error on every iteration or not.
verbose: if set to True or 1, prints information at each step of
the solver; if set to 2, also prints information from the linesearch.

Reference:
Jorge Nocedal and Stephen Wright.
@@ -141,7 +142,7 @@ class BFGS(base.IterativeSolver):
jit: bool = True
unroll: base.AutoOrBoolean = "auto"

verbose: bool = False
verbose: Union[bool, int] = False

def init_state(self,
init_params: Any,
@@ -260,6 +261,17 @@ def update(self,
num_fun_eval=new_num_fun_eval,
num_linesearch_iter=new_num_linesearch_iter)

if self.verbose:
self.log_info(
new_state,
error_name="Gradient Norm",
additional_info={
"Objective Value": new_value,
"Stepsize": new_stepsize,
"Number Linesearch Iterations":
new_state.num_linesearch_iter - state.num_linesearch_iter
}
)
return base.OptStep(params=new_params, state=new_state)

def optimality_fun(self, params, *args, **kwargs):
@@ -289,7 +301,7 @@ def __post_init__(self):
max_stepsize=self.max_stepsize,
jit=self.jit,
unroll=unroll,
verbose=self.verbose,
verbose=int(self.verbose)-1
)
self.run_ls = self.linesearch_solver.run

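Note the int(self.verbose)-1 handed to the inner BacktrackingLineSearch in __post_init__: the line search gets a verbosity one level below the solver's, which yields the cascade described in the docstring above. Roughly:

# Assumed mapping (sketch of the int(verbose) - 1 rule):
#   BFGS verbose = False or 0  ->  linesearch verbose = -1  (silent)
#   BFGS verbose = True  or 1  ->  linesearch verbose =  0  (silent)
#   BFGS verbose = 2           ->  linesearch verbose =  1  (logs each backtracking step)

Broyden applies the same rule when constructing its line search (see below).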
14 changes: 12 additions & 2 deletions jaxopt/_src/bisection.py
@@ -18,6 +18,7 @@
from typing import Callable
from typing import NamedTuple
from typing import Optional
from typing import Union

from dataclasses import dataclass

@@ -56,7 +57,7 @@ class Bisection(base.IterativeSolver):
check_bracket: whether to check correctness of the bracketing interval.
If True, the method ``run`` cannot be jitted.
implicit_diff_solve: the linear system solver to use.
verbose: whether to print error on every iteration or not.
verbose: whether to print information on every iteration or not.
jit: whether to JIT-compile the bisection loop (default: True).
unroll: whether to unroll the bisection loop (default: "auto").

@@ -67,7 +68,7 @@
maxiter: int = 30
tol: float = 1e-5
check_bracket: bool = True
verbose: bool = False
verbose: Union[bool, int] = False
implicit_diff_solve: Optional[Callable] = None
has_aux: bool = False
jit: bool = True
@@ -151,6 +152,15 @@ def update(self,
aux=aux,
num_fun_eval=state.num_fun_eval + 1)

if self.verbose:
self.log_info(
state,
error_name="Absolute Value Output",
additional_info={
"High Point": high,
"Low Point": low
}
)
return base.OptStep(params=params, state=state)

def run(self,
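A hedged usage sketch of the new bracket logging (constructor arguments per jaxopt's documented Bisection API):

from jaxopt import Bisection

# Root of x^3 - x - 2 on [1, 2]; verbose=True logs |optimality_fun(x)|
# ("Absolute Value Output") plus the current high/low bracket ends.
bisec = Bisection(optimality_fun=lambda x: x ** 3 - x - 2.0,
                  lower=1.0, upper=2.0, verbose=True)
params, state = bisec.run()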
9 changes: 7 additions & 2 deletions jaxopt/_src/block_cd.py
@@ -64,7 +64,7 @@ class BlockCoordinateDescent(base.IterativeSolver):

maxiter: maximum number of proximal gradient descent iterations.
tol: tolerance to use.
verbose: whether to print error on every iteration or not.
verbose: whether to print information on every iteration or not.

implicit_diff: whether to enable implicit diff or autodiff of unrolled
iterations.
@@ -77,7 +77,7 @@
block_prox: Callable
maxiter: int = 500
tol: float = 1e-4
verbose: int = 0
verbose: Union[bool, int] = False
implicit_diff: bool = True
implicit_diff_solve: Optional[Callable] = None
jit: bool = True
@@ -167,6 +167,11 @@ def body_fun(i, tup):
num_grad_eval=state.num_grad_eval + n_for,
num_prox_eval=state.num_prox_eval + n_for)

if self.verbose:
self.log_info(
state,
error_name="Distance btw Iterates"
)
return base.OptStep(params=params, state=state)

def _fixed_point_fun(self, params, hyperparams_prox, *args, **kwargs):
19 changes: 14 additions & 5 deletions jaxopt/_src/broyden.py
@@ -177,8 +177,8 @@ class Broyden(base.IterativeSolver):
jit: whether to JIT-compile the optimization loop (default: True).
unroll: whether to unroll the optimization loop (default: "auto").

verbose: whether to print error on every iteration or not.
Warning: verbose=True will automatically disable jit.
verbose: if set to True or 1, prints information at each step of
the solver; if set to 2, also prints information from the linesearch.

Reference:
Charles G. Broyden.
@@ -212,12 +212,10 @@ class Broyden(base.IterativeSolver):
jit: bool = True
unroll: base.AutoOrBoolean = "auto"

verbose: bool = False
verbose: Union[bool, int] = False

def _cond_fun(self, inputs):
_, state = inputs[0]
if self.verbose:
jax.debug.print("Solver: Broyden, Error: {error}", error=state.error)
# We continue the optimization loop while the error tolerance is not met and,
# either failed linesearch is disallowed or linesearch hasn't failed.
return (state.error > self.tol) & jnp.logical_or(not self.stop_if_linesearch_fails, ~state.failed_linesearch)
@@ -330,6 +328,7 @@ def ls_fun_with_aux(params, *args, **kwargs):
jit=self.jit,
unroll=self.unroll,
has_aux=True,
verbose=int(self.verbose)-1,
tol=1e-2)
init_stepsize = jnp.where(state.stepsize <= self.min_stepsize,
# If stepsize became too small, we restart it.
@@ -382,6 +381,16 @@
num_linesearch_iter=new_num_linesearch_iter,
failed_linesearch=failed_linesearch)

if self.verbose:
self.log_info(
new_state,
error_name="Norm Output",
additional_info={
"Stepsize": new_stepsize,
"Number Linesearch Iterations":
new_state.num_linesearch_iter - state.num_linesearch_iter
}
)
return base.OptStep(params=new_params, state=new_state)

def optimality_fun(self, params, *args, **kwargs):
7 changes: 5 additions & 2 deletions jaxopt/_src/cd_qp.py
@@ -17,6 +17,7 @@
from typing import Callable
from typing import NamedTuple
from typing import Optional
from typing import Union

from dataclasses import dataclass

@@ -62,7 +63,7 @@ class BoxCDQP(base.IterativeSolver):
Attributes:
maxiter: maximum number of coordinate descent iterations.
tol: tolerance to use.
verbose: whether to print error on every iteration or not.
verbose: whether to print information on every iteration or not.

implicit_diff: whether to enable implicit diff or autodiff of unrolled
iterations.
@@ -73,7 +74,7 @@
"""
maxiter: int = 500
tol: float = 1e-4
verbose: int = 0
verbose: Union[bool, int] = False
implicit_diff: bool = True
implicit_diff_solve: Optional[Callable] = None
jit: bool = True
@@ -124,6 +125,8 @@ def update(self,

state = BoxCDQPState(iter_num=state.iter_num + 1, error=error)

if self.verbose:
self.log_info(state)
return base.OptStep(params=params, state=state)

def _fixed_point_fun(self,