From 9e173ac603da9c03f96a5868bc8f3f3fc18be716 Mon Sep 17 00:00:00 2001 From: Vincent Roulet Date: Fri, 6 Oct 2023 11:12:38 -0700 Subject: [PATCH] expanded verbose for all algorithms, made verbose a boolean for all algorithms --- jaxopt/_src/anderson.py | 7 ++- jaxopt/_src/anderson_wrapper.py | 7 ++- jaxopt/_src/armijo_sgd.py | 14 ++++- jaxopt/_src/backtracking_linesearch.py | 16 ++++- jaxopt/_src/base.py | 18 +++++- jaxopt/_src/bfgs.py | 18 +++++- jaxopt/_src/bisection.py | 14 ++++- jaxopt/_src/block_cd.py | 9 ++- jaxopt/_src/broyden.py | 19 ++++-- jaxopt/_src/cd_qp.py | 7 ++- jaxopt/_src/fixed_point_iteration.py | 11 +++- jaxopt/_src/gauss_newton.py | 14 ++++- jaxopt/_src/gradient_descent.py | 2 +- jaxopt/_src/hager_zhang_linesearch.py | 14 ++++- jaxopt/_src/iterative_refinement.py | 6 +- jaxopt/_src/lbfgs.py | 18 ++++-- jaxopt/_src/lbfgsb.py | 21 +++++-- jaxopt/_src/levenberg_marquardt.py | 16 ++++- jaxopt/_src/mirror_descent.py | 10 +++- jaxopt/_src/nonlinear_cg.py | 20 +++++-- jaxopt/_src/optax_wrapper.py | 12 +++- jaxopt/_src/osqp.py | 16 ++--- jaxopt/_src/polyak_sgd.py | 15 ++++- jaxopt/_src/projected_gradient.py | 4 +- jaxopt/_src/proximal_gradient.py | 13 +++- jaxopt/_src/zoom_linesearch.py | 83 +++++++++++++------------- tests/common_test.py | 30 +++++----- tests/zoom_linesearch_test.py | 2 +- 28 files changed, 309 insertions(+), 127 deletions(-) diff --git a/jaxopt/_src/anderson.py b/jaxopt/_src/anderson.py index 84012eac..15f0c114 100644 --- a/jaxopt/_src/anderson.py +++ b/jaxopt/_src/anderson.py @@ -18,6 +18,7 @@ from typing import Callable from typing import NamedTuple from typing import List +from typing import Union from typing import Optional from dataclasses import dataclass @@ -134,7 +135,7 @@ class AndersonAcceleration(base.IterativeSolver): has_aux: wether fixed_point_fun returns additional data. (default: False) This additional data is not taken into account by the fixed point. The solver returns the `aux` associated to the last iterate (i.e the fixed point). - verbose: whether to print error on every iteration or not. + verbose: whether to print information on every iteration or not. implicit_diff: whether to enable implicit diff or autodiff of unrolled iterations. implicit_diff_solve: the linear system solver to use. @@ -154,7 +155,7 @@ class AndersonAcceleration(base.IterativeSolver): tol: float = 1e-5 ridge: float = 1e-5 has_aux: bool = False - verbose: bool = False + verbose: Union[bool, int] = False implicit_diff: bool = True implicit_diff_solve: Optional[Callable] = None jit: bool = True @@ -246,6 +247,8 @@ def use_param(t): aux=aux, num_fun_eval=state.num_fun_eval+1) + if self.verbose: + self.log_info(next_state, error_name="Residual Norm") return base.OptStep(params=next_params, state=next_state) def optimality_fun(self, params, *args, **kwargs): diff --git a/jaxopt/_src/anderson_wrapper.py b/jaxopt/_src/anderson_wrapper.py index 6ffc0af1..dea00a0e 100644 --- a/jaxopt/_src/anderson_wrapper.py +++ b/jaxopt/_src/anderson_wrapper.py @@ -68,7 +68,7 @@ class AndersonWrapper(base.IterativeSolver): beta: momentum in Anderson updates. (default: 1). ridge: ridge regularization in solver. Consider increasing this value if the solver returns ``NaN``. - verbose: whether to print error on every iteration or not. + verbose: whether to print information on every iteration or not. implicit_diff: whether to enable implicit diff or autodiff of unrolled iterations. implicit_diff_solve: the linear system solver to use. 
@@ -80,7 +80,7 @@ class AndersonWrapper(base.IterativeSolver): mixing_frequency: int = None beta: float = 1. ridge: float = 1e-5 - verbose: bool = False + verbose: Union[bool, int] = False implicit_diff: bool = True implicit_diff_solve: Optional[Callable] = None jit: bool = True @@ -161,6 +161,9 @@ def use_param(t): params_history=params_history, residuals_history=residuals_history, residual_gram=residual_gram) + + if self.verbose: + self.log_info(next_state, error_name="Inner Solver Error") return base.OptStep(params=next_params, state=next_state) def optimality_fun(self, params, *args, **kwargs): diff --git a/jaxopt/_src/armijo_sgd.py b/jaxopt/_src/armijo_sgd.py index 37475c14..c9b79b93 100644 --- a/jaxopt/_src/armijo_sgd.py +++ b/jaxopt/_src/armijo_sgd.py @@ -21,6 +21,7 @@ from typing import Callable from typing import NamedTuple from typing import Optional +from typing import Union import jax import jax.lax as lax @@ -191,7 +192,7 @@ class ArmijoSGD(base.StochasticSolver): maxiter: maximum number of solver iterations. maxls: maximum number of steps in line search. tol: tolerance to use. - verbose: whether to print error on every iteration or not. + verbose: whether to print information on every iteration or not. implicit_diff: whether to enable implicit diff or autodiff of unrolled iterations. @@ -224,7 +225,7 @@ class ArmijoSGD(base.StochasticSolver): maxiter: int = 500 maxls: int = 15 tol: float = 1e-3 - verbose: int = 0 + verbose: Union[bool, int] = False implicit_diff: bool = False implicit_diff_solve: Optional[Callable] = None @@ -315,6 +316,15 @@ def update(self, params, state, *args, **kwargs) -> base.OptStep: stepsize=jnp.asarray(stepsize, dtype=dtype), velocity=next_velocity) + if self.verbose: + self.log_info( + next_state, + error_name="Gradient Norm", + additional_info={ + 'Objective Value': next_state.value, + 'Stepsize': stepsize + }, + ) return base.OptStep(next_params, next_state) def optimality_fun(self, params, *args, **kwargs): diff --git a/jaxopt/_src/backtracking_linesearch.py b/jaxopt/_src/backtracking_linesearch.py index 189e3df5..72db4722 100644 --- a/jaxopt/_src/backtracking_linesearch.py +++ b/jaxopt/_src/backtracking_linesearch.py @@ -77,7 +77,7 @@ class BacktrackingLineSearch(base.IterativeLineSearch): maxiter: maximum number of line search iterations. tol: tolerance of the stopping criterion. - verbose: whether to print error on every iteration or not. + verbose: whether to print information on every iteration or not. jit: whether to JIT-compile the optimization loop (default: "auto"). unroll: whether to unroll the optimization loop (default: "auto"). 
@@ -95,7 +95,7 @@ class BacktrackingLineSearch(base.IterativeLineSearch): decrease_factor: float = 0.8 max_stepsize: float = 1.0 - verbose: int = 0 + verbose: Union[bool, int] = False jit: base.AutoOrBoolean = "auto" unroll: base.AutoOrBoolean = "auto" @@ -283,6 +283,18 @@ def update( num_fun_eval=num_fun_eval, num_grad_eval=num_grad_eval) + if self.verbose: + additional_info = {'Stepsize': stepsize, 'Objective Value': new_value} + if self.condition != 'armijo': + error_name = "Minimum Decrease & Curvature Errors" + additional_info.update({'Decrease Error': error_cond1}) + else: + error_name = "Decrease Error" + self.log_info( + new_state, + error_name=error_name, + additional_info=additional_info + ) return base.LineSearchStep(stepsize=new_stepsize, state=new_state) def _compute_final_grad(self, params, fun_args, fun_kwargs): diff --git a/jaxopt/_src/base.py b/jaxopt/_src/base.py index e9aa2037..e3b7a507 100644 --- a/jaxopt/_src/base.py +++ b/jaxopt/_src/base.py @@ -264,10 +264,22 @@ def _get_unroll_option(self): def _cond_fun(self, inputs): _, state = inputs[0] - if self.verbose: - name = self.__class__.__name__ - jax.debug.print("Solver: %s, Error: {error}" % name, error=state.error) return state.error > self.tol + + def log_info(self, state, error_name='Error', additional_info={}): + """Base info at the end of the update.""" + other_info_kw = ' '.join([key + ":{} " for key in additional_info.keys()]) + name = self.__class__.__name__ + jax.debug.print( + "INFO: jaxopt." + name + ": " + \ + "Iter: {} " + \ + error_name + " (stop. crit.): {} " + \ + other_info_kw, + state.iter_num, + state.error, + *additional_info.values(), + ordered=True + ) def _body_fun(self, inputs): (params, state), (args, kwargs) = inputs diff --git a/jaxopt/_src/bfgs.py b/jaxopt/_src/bfgs.py index 5bc6c295..d7b6939b 100644 --- a/jaxopt/_src/bfgs.py +++ b/jaxopt/_src/bfgs.py @@ -107,7 +107,8 @@ class BFGS(base.IterativeSolver): implicit_diff_solve: the linear system solver to use. jit: whether to JIT-compile the optimization loop (default: True). unroll: whether to unroll the optimization loop (default: "auto"). - verbose: whether to print error on every iteration or not. + verbose: if set to True or 1 prints the information at each step of + the solver, if set to 2, print also the information of the linesearch. Reference: Jorge Nocedal and Stephen Wright. 
@@ -141,7 +142,7 @@ class BFGS(base.IterativeSolver): jit: bool = True unroll: base.AutoOrBoolean = "auto" - verbose: bool = False + verbose: Union[bool, int] = False def init_state(self, init_params: Any, @@ -260,6 +261,17 @@ def update(self, num_fun_eval=new_num_fun_eval, num_linesearch_iter=new_num_linesearch_iter) + if self.verbose: + self.log_info( + new_state, + error_name="Gradient Norm", + additional_info={ + "Objective Value": new_value, + "Stepsize": new_stepsize, + "Number Linesearch Iterations": + new_state.num_linesearch_iter - state.num_linesearch_iter + } + ) return base.OptStep(params=new_params, state=new_state) def optimality_fun(self, params, *args, **kwargs): @@ -289,7 +301,7 @@ def __post_init__(self): max_stepsize=self.max_stepsize, jit=self.jit, unroll=unroll, - verbose=self.verbose, + verbose=int(self.verbose)-1 ) self.run_ls = self.linesearch_solver.run diff --git a/jaxopt/_src/bisection.py b/jaxopt/_src/bisection.py index a5ce452e..943a20d3 100644 --- a/jaxopt/_src/bisection.py +++ b/jaxopt/_src/bisection.py @@ -18,6 +18,7 @@ from typing import Callable from typing import NamedTuple from typing import Optional +from typing import Union from dataclasses import dataclass @@ -56,7 +57,7 @@ class Bisection(base.IterativeSolver): check_bracket: whether to check correctness of the bracketing interval. If True, the method ``run`` cannot be jitted. implicit_diff_solve: the linear system solver to use. - verbose: whether to print error on every iteration or not. + verbose: whether to print information on every iteration or not. jit: whether to JIT-compile the bisection loop (default: True). unroll: whether to unroll the bisection loop (default: "auto"). @@ -67,7 +68,7 @@ class Bisection(base.IterativeSolver): maxiter: int = 30 tol: float = 1e-5 check_bracket: bool = True - verbose: bool = False + verbose: Union[bool, int] = False implicit_diff_solve: Optional[Callable] = None has_aux: bool = False jit: bool = True @@ -151,6 +152,15 @@ def update(self, aux=aux, num_fun_eval=state.num_fun_eval + 1) + if self.verbose: + self.log_info( + state, + error_name="Absolute Value Output", + additional_info={ + "High Point": high, + "Low Point": low + } + ) return base.OptStep(params=params, state=state) def run(self, diff --git a/jaxopt/_src/block_cd.py b/jaxopt/_src/block_cd.py index 78d8312a..d77eb4d9 100644 --- a/jaxopt/_src/block_cd.py +++ b/jaxopt/_src/block_cd.py @@ -64,7 +64,7 @@ class BlockCoordinateDescent(base.IterativeSolver): maxiter: maximum number of proximal gradient descent iterations. tol: tolerance to use. - verbose: whether to print error on every iteration or not. + verbose: whether to print information on every iteration or not. implicit_diff: whether to enable implicit diff or autodiff of unrolled iterations. 
@@ -77,7 +77,7 @@ class BlockCoordinateDescent(base.IterativeSolver): block_prox: Callable maxiter: int = 500 tol: float = 1e-4 - verbose: int = 0 + verbose: Union[bool, int] = False implicit_diff: bool = True implicit_diff_solve: Optional[Callable] = None jit: bool = True @@ -167,6 +167,11 @@ def body_fun(i, tup): num_grad_eval=state.num_grad_eval + n_for, num_prox_eval=state.num_prox_eval + n_for) + if self.verbose: + self.log_info( + state, + error_name="Distance btw Iterates" + ) return base.OptStep(params=params, state=state) def _fixed_point_fun(self, params, hyperparams_prox, *args, **kwargs): diff --git a/jaxopt/_src/broyden.py b/jaxopt/_src/broyden.py index bd54222d..df968319 100644 --- a/jaxopt/_src/broyden.py +++ b/jaxopt/_src/broyden.py @@ -177,8 +177,8 @@ class Broyden(base.IterativeSolver): jit: whether to JIT-compile the optimization loop (default: True). unroll: whether to unroll the optimization loop (default: "auto"). - verbose: whether to print error on every iteration or not. - Warning: verbose=True will automatically disable jit. + verbose: if set to True or 1 prints the information at each step of + the solver, if set to 2, print also the information of the linesearch. Reference: Charles G. Broyden. @@ -212,12 +212,10 @@ class Broyden(base.IterativeSolver): jit: bool = True unroll: base.AutoOrBoolean = "auto" - verbose: bool = False + verbose: Union[bool, int] = False def _cond_fun(self, inputs): _, state = inputs[0] - if self.verbose: - jax.debug.print("Solver: Broyden, Error: {error}", error=state.error) # We continue the optimization loop while the error tolerance is not met and, # either failed linesearch is disallowed or linesearch hasn't failed. return (state.error > self.tol) & jnp.logical_or(not self.stop_if_linesearch_fails, ~state.failed_linesearch) @@ -330,6 +328,7 @@ def ls_fun_with_aux(params, *args, **kwargs): jit=self.jit, unroll=self.unroll, has_aux=True, + verbose=int(self.verbose)-1, tol=1e-2) init_stepsize = jnp.where(state.stepsize <= self.min_stepsize, # If stepsize became too small, we restart it. @@ -382,6 +381,16 @@ def ls_fun_with_aux(params, *args, **kwargs): num_linesearch_iter=new_num_linesearch_iter, failed_linesearch=failed_linesearch) + if self.verbose: + self.log_info( + new_state, + error_name="Norm Output", + additional_info={ + "Stepsize": new_stepsize, + "Number Linesearch Iterations": + new_state.num_linesearch_iter - state.num_linesearch_iter + } + ) return base.OptStep(params=new_params, state=new_state) def optimality_fun(self, params, *args, **kwargs): diff --git a/jaxopt/_src/cd_qp.py b/jaxopt/_src/cd_qp.py index b067beb8..bd994522 100644 --- a/jaxopt/_src/cd_qp.py +++ b/jaxopt/_src/cd_qp.py @@ -17,6 +17,7 @@ from typing import Callable from typing import NamedTuple from typing import Optional +from typing import Union from dataclasses import dataclass @@ -62,7 +63,7 @@ class BoxCDQP(base.IterativeSolver): Attributes: maxiter: maximum number of coordinate descent iterations. tol: tolerance to use. - verbose: whether to print error on every iteration or not. + verbose: whether to print information on every iteration or not. implicit_diff: whether to enable implicit diff or autodiff of unrolled iterations. 
@@ -73,7 +74,7 @@ class BoxCDQP(base.IterativeSolver): """ maxiter: int = 500 tol: float = 1e-4 - verbose: int = 0 + verbose: Union[bool, int] = False implicit_diff: bool = True implicit_diff_solve: Optional[Callable] = None jit: bool = True @@ -124,6 +125,8 @@ def update(self, state = BoxCDQPState(iter_num=state.iter_num + 1, error=error) + if self.verbose: + self.log_info(state) return base.OptStep(params=params, state=state) def _fixed_point_fun(self, diff --git a/jaxopt/_src/fixed_point_iteration.py b/jaxopt/_src/fixed_point_iteration.py index 930ff586..24190449 100644 --- a/jaxopt/_src/fixed_point_iteration.py +++ b/jaxopt/_src/fixed_point_iteration.py @@ -18,6 +18,7 @@ from typing import Callable from typing import NamedTuple from typing import Optional +from typing import Union from dataclasses import dataclass @@ -54,7 +55,7 @@ class FixedPointIteration(base.IterativeSolver): has_aux: wether fixed_point_fun returns additional data. (default: False) if True, the fixed is computed only with respect to first element of the sequence returned. Other elements are carried during computation. - verbose: whether to print error on every iteration or not. + verbose: whether to print information on every iteration or not. implicit_diff: whether to enable implicit diff or autodiff of unrolled iterations. @@ -69,7 +70,7 @@ class FixedPointIteration(base.IterativeSolver): maxiter: int = 100 tol: float = 1e-5 has_aux: bool = False - verbose: bool = False + verbose: Union[bool, int] = False implicit_diff: bool = True implicit_diff_solve: Optional[Callable] = None jit: bool = True @@ -116,6 +117,12 @@ def update(self, error=error, aux=aux, num_fun_eval=state.num_fun_eval + 1) + + if self.verbose: + self.log_info( + next_state, + error_name="Distance btw Iterates" + ) return base.OptStep(params=next_params, state=next_state) def optimality_fun(self, params, *args, **kwargs): diff --git a/jaxopt/_src/gauss_newton.py b/jaxopt/_src/gauss_newton.py index 30a9fbcc..ea1d89cd 100644 --- a/jaxopt/_src/gauss_newton.py +++ b/jaxopt/_src/gauss_newton.py @@ -18,6 +18,7 @@ from typing import Callable from typing import NamedTuple from typing import Optional +from typing import Union from dataclasses import dataclass @@ -57,7 +58,7 @@ class GaussNewton(base.IterativeSolver): iterations. implicit_diff_solve: the linear system solver to use. has_aux: whether ``residual_fun`` outputs auxiliary data or not. - verbose: whether to print error on every iteration or not. + verbose: whether to print information on every iteration or not. jit: whether to JIT-compile the bisection loop (default: True). unroll: whether to unroll the bisection loop (default: "auto"). 
""" @@ -67,7 +68,7 @@ class GaussNewton(base.IterativeSolver): implicit_diff: bool = True implicit_diff_solve: Optional[Callable] = None has_aux: bool = False - verbose: bool = False + verbose: Union[bool, int] = False jit: bool = True unroll: base.AutoOrBoolean = "auto" @@ -116,15 +117,22 @@ def update(self, delta_params = linear_solve.solve_cg(matvec, gradient) params = tree_sub(params, delta_params) + value = 0.5 * jnp.sum(jnp.square(residual)) state = GaussNewtonState(iter_num=state.iter_num + 1, error=tree_l2_norm(delta_params), residual=residual, - value=0.5 * jnp.sum(jnp.square(residual)), + value=value, delta=delta_params, gradient=gradient, aux=aux) + if self.verbose: + self.log_info( + state, + error_name="Norm GN Update", + additional_info={"Objective Value": value} + ) return base.OptStep(params=params, state=state) def __post_init__(self): diff --git a/jaxopt/_src/gradient_descent.py b/jaxopt/_src/gradient_descent.py index 66963788..bbaf100f 100644 --- a/jaxopt/_src/gradient_descent.py +++ b/jaxopt/_src/gradient_descent.py @@ -54,7 +54,7 @@ class GradientDescent(ProximalGradient): tol: tolerance to use. acceleration: whether to use acceleration (also known as FISTA) or not. - verbose: whether to print error on every iteration or not. + verbose: whether to print information on every iteration or not. implicit_diff: whether to enable implicit diff or autodiff of unrolled iterations. diff --git a/jaxopt/_src/hager_zhang_linesearch.py b/jaxopt/_src/hager_zhang_linesearch.py index 56c0bce9..6d03cd92 100644 --- a/jaxopt/_src/hager_zhang_linesearch.py +++ b/jaxopt/_src/hager_zhang_linesearch.py @@ -87,7 +87,7 @@ class HagerZhangLineSearch(base.IterativeLineSearch): maxiter: maximum number of line search iterations. tol: tolerance of the stopping criterion. - verbose: whether to print error on every iteration or not. + verbose: whether to print information on every iteration or not. jit: whether to JIT-compile the optimization loop (default: "auto"). unroll: whether to unroll the optimization loop (default: "auto"). @@ -107,7 +107,7 @@ class HagerZhangLineSearch(base.IterativeLineSearch): # TODO(vroulet): remove max_stepsize argument as it is not used max_stepsize: float = 1.0 - verbose: int = 0 + verbose: Union[bool, int] = False jit: base.AutoOrBoolean = "auto" unroll: base.AutoOrBoolean = "auto" @@ -554,6 +554,16 @@ def _reupdate(): num_fun_eval=new_num_fun_eval, num_grad_eval=new_num_grad_eval) + if self.verbose: + self.log_info( + new_state, + error_name="Minimum Decrease & Curvature Errors", + additional_info={ + "Stepsize": new_stepsize, + "Objective Value": new_value + } + ) + return base.LineSearchStep(stepsize=new_stepsize, state=new_state) def __post_init__(self): diff --git a/jaxopt/_src/iterative_refinement.py b/jaxopt/_src/iterative_refinement.py index 6cc114fa..70667aad 100644 --- a/jaxopt/_src/iterative_refinement.py +++ b/jaxopt/_src/iterative_refinement.py @@ -98,7 +98,7 @@ class IterativeRefinement(base.IterativeSolver): This solver can be inaccurate and run with low precision. maxiter: maximum number of iterations (default: 10). tol: absolute tolerance for stoping criterion (default: 1e-7). - verbose: If verbose=1, print error at each iteration. + verbose: whether to print information on every iteration or not. implicit_diff: whether to enable implicit diff or autodiff of unrolled iterations. implicit_diff_solve: the linear system solver to use. 
@@ -118,7 +118,7 @@ class IterativeRefinement(base.IterativeSolver): solve: Callable = partial(linear_solve.solve_gmres, ridge=1e-6) maxiter: int = 10 tol: float = 1e-7 - verbose: int = 0 + verbose: Union[bool, int] = False implicit_diff_solve: Optional[Callable] = None jit: bool = True unroll: base.AutoOrBoolean = "auto" @@ -175,6 +175,8 @@ def update(self, num_matvec_bar_eval=state.num_matvec_bar_eval + 1, num_solve_eval=state.num_solve_eval + 1) + if self.verbose: + self.log_info(state, error_name="Residual Norm") return base.OptStep(params, state) def run(self, diff --git a/jaxopt/_src/lbfgs.py b/jaxopt/_src/lbfgs.py index f80a5e7b..1be37e1b 100644 --- a/jaxopt/_src/lbfgs.py +++ b/jaxopt/_src/lbfgs.py @@ -195,7 +195,8 @@ class LBFGS(base.IterativeSolver): implicit_diff_solve: the linear system solver to use. jit: whether to JIT-compile the optimization loop (default: True). unroll: whether to unroll the optimization loop (default: "auto"). - verbose: whether to print error on every iteration or not. + verbose: if set to True or 1 prints the information at each step of + the solver, if set to 2, print also the information of the linesearch. References: Jorge Nocedal and Stephen Wright. @@ -237,12 +238,10 @@ class LBFGS(base.IterativeSolver): jit: bool = True unroll: base.AutoOrBoolean = "auto" - verbose: bool = False + verbose: Union[bool, int] = False def _cond_fun(self, inputs): _, state = inputs[0] - if self.verbose: - jax.debug.print("Solver: LBFGS, Error: {error}", error=state.error) # We continue the optimization loop while the error tolerance is not met and, # either failed linesearch is disallowed or linesearch hasn't failed. return (state.error > self.tol) & jnp.logical_or(not self.stop_if_linesearch_fails, ~state.failed_linesearch) @@ -405,6 +404,17 @@ def update(self, num_fun_eval=new_num_fun_eval, num_linesearch_iter=new_num_linesearch_iter) + if self.verbose: + self.log_info( + new_state, + error_name="Gradient Norm", + additional_info={ + "Objective Value": new_value, + "Stepsize": new_stepsize, + "Number Linesearch Iterations": + new_state.num_linesearch_iter - state.num_linesearch_iter + } + ) return base.OptStep(params=new_params, state=new_state) def optimality_fun(self, params, *args, **kwargs): diff --git a/jaxopt/_src/lbfgsb.py b/jaxopt/_src/lbfgsb.py index 4a512437..030e7082 100644 --- a/jaxopt/_src/lbfgsb.py +++ b/jaxopt/_src/lbfgsb.py @@ -273,8 +273,8 @@ class LBFGSB(base.IterativeSolver): implicit_diff_solve: the linear system solver to use. jit: whether to JIT-compile the optimization loop (default: True). unroll: whether to unroll the optimization loop (default: "auto"). - verbose: whether to print error on every iteration or not. - Warning: verbose=True will automatically disable jit. + verbose: if set to True or 1 prints the information at each step of + the solver, if set to 2, print also the information of the linesearch. """ fun: Callable # pylint: disable=g-bare-generic @@ -306,12 +306,10 @@ class LBFGSB(base.IterativeSolver): jit: bool = True unroll: base.AutoOrBoolean = "auto" - verbose: bool = False + verbose: Union[bool, int] = False def _cond_fun(self, inputs): _, state = inputs[0] - if self.verbose: - print(self.__class__.__name__ + " error:", state.error) # We continue the optimization loop while the error tolerance is not met # and either failed linesearch is disallowed or linesearch hasn't failed. 
return (state.error > self.tol) & jnp.logical_or( @@ -558,6 +556,17 @@ def update( num_linesearch_iter=new_num_linesearch_iter, ) + if self.verbose: + self.log_info( + new_state, + error_name="Projected Gradient Norm", + additional_info={ + "Objective Value": new_value, + "Stepsize": new_stepsize, + "Number Linesearch Iterations": + new_state.num_linesearch_iter - state.num_linesearch_iter + } + ) return base.OptStep(new_params, new_state) def _fixed_point_fun(self, sol, bounds, args, kwargs): @@ -609,7 +618,7 @@ def __post_init__(self): max_stepsize=self.max_stepsize, jit=self.jit, unroll=unroll, - verbose=self.verbose, + verbose=int(self.verbose)-1, ) self.run_ls = linesearch_solver.run diff --git a/jaxopt/_src/levenberg_marquardt.py b/jaxopt/_src/levenberg_marquardt.py index 5ded2f1f..1a3646ac 100644 --- a/jaxopt/_src/levenberg_marquardt.py +++ b/jaxopt/_src/levenberg_marquardt.py @@ -122,7 +122,7 @@ class LevenbergMarquardt(base.IterativeSolver): contribution_ratio_threshold: float, the threshold for acceleration/velocity ratio. We update the parameters in the algorithm only if the ratio is smaller than this threshold value. - verbose: bool, whether to print error on every iteration or not. + verbose: bool, whether to print information on every iteration or not. jac_fun: Callable, a function to calculate the Jacobian. If not None, this function is used instead of directly calculating it using ``jax.jacfwd``. materialize_jac: bool, whether to materialize Jacobian. If this option is @@ -156,7 +156,7 @@ class LevenbergMarquardt(base.IterativeSolver): geodesic: bool = False contribution_ratio_threshold = 0.75 - verbose: bool = False + verbose: Union[bool, int] = False jac_fun: Optional[Callable[..., jnp.ndarray]] = None materialize_jac: bool = False implicit_diff: bool = True @@ -433,13 +433,14 @@ def update(self, params, state: NamedTuple, *args, **kwargs) -> base.OptStep: state.jac, state.jt, state.jtj, state.hess_res, state.aux, *args, **kwargs)) + new_value = 0.5 * jnp.sum(jnp.square(residual)) state = LevenbergMarquardtState( iter_num=state.iter_num + 1, damping_factor=damping_factor, increase_factor=increase_factor, error=tree_l2_norm(gradient), residual=residual, - value=0.5 * jnp.sum(jnp.square(residual)), + value=new_value, delta=delta_params, gradient=gradient, jac=jac, @@ -448,6 +449,15 @@ def update(self, params, state: NamedTuple, *args, **kwargs) -> base.OptStep: hess_res=hess_res, aux=aux) + if self.verbose: + self.log_info( + state, + error_name="Gradient Norm", + additional_info={ + "Objective Value": new_value, + "Damping Factor": damping_factor + } + ) return base.OptStep(params=params, state=state) def __post_init__(self): diff --git a/jaxopt/_src/mirror_descent.py b/jaxopt/_src/mirror_descent.py index 5f3b1da1..1e962770 100644 --- a/jaxopt/_src/mirror_descent.py +++ b/jaxopt/_src/mirror_descent.py @@ -67,7 +67,7 @@ class MirrorDescent(base.IterativeSolver): each iteration. maxiter: maximum number of mirror descent iterations. tol: tolerance to use. - verbose: whether to print error on every iteration or not. + verbose: whether to print information on every iteration or not. implicit_diff: whether to enable implicit diff or autodiff of unrolled iterations. implicit_diff_solve: the linear system solver to use. 
@@ -86,7 +86,7 @@ class MirrorDescent(base.IterativeSolver): stepsize: Union[float, Callable] maxiter: int = 500 tol: float = 1e-2 - verbose: int = 0 + verbose: Union[bool, int] = False implicit_diff: bool = True implicit_diff_solve: Optional[Callable] = None has_aux: bool = False @@ -165,6 +165,12 @@ def _update(self, x, state, hyperparams_proj, args, kwargs): num_fun_eval=state.num_fun_eval + 1, num_grad_eval=state.num_grad_eval + 1, num_proj_eval=state.num_proj_eval + 1,) + + if self.verbose: + self.log_info( + next_state, + error_name="Distance btw Iterates" + ) return base.OptStep(params=next_x, state=next_state) def update(self, diff --git a/jaxopt/_src/nonlinear_cg.py b/jaxopt/_src/nonlinear_cg.py index 0b93250e..398f3db2 100644 --- a/jaxopt/_src/nonlinear_cg.py +++ b/jaxopt/_src/nonlinear_cg.py @@ -16,7 +16,7 @@ import warnings from dataclasses import dataclass -from typing import Any, Callable, NamedTuple, Optional +from typing import Any, Callable, NamedTuple, Optional, Union import jax import jax.numpy as jnp @@ -99,7 +99,8 @@ class NonlinearCG(base.IterativeSolver): jit: whether to JIT-compile the optimization loop (default: True). unroll: whether to unroll the optimization loop (default: "auto"). - verbose: whether to print error on every iteration or not. + verbose: if set to True or 1 prints the information at each step of + the solver, if set to 2, print also the information of the linesearch. References: Jorge Nocedal and Stephen Wright. @@ -135,7 +136,7 @@ class NonlinearCG(base.IterativeSolver): jit: bool = True unroll: base.AutoOrBoolean = "auto" - verbose: int = 0 + verbose: Union[bool, int] = False def init_state(self, init_params: Any, @@ -269,6 +270,17 @@ def update(self, num_grad_eval=new_num_grad_eval, num_linesearch_iter=new_num_linesearch_iter) + if self.verbose: + self.log_info( + new_state, + error_name="Gradient Norm", + additional_info={ + "Objective Value": new_value, + "Stepsize": new_stepsize, + "Number Linesearch Iterations": + new_state.num_linesearch_iter - state.num_linesearch_iter + } + ) return base.OptStep(params=new_params, state=new_state) def optimality_fun(self, params, *args, **kwargs): @@ -302,7 +314,7 @@ def __post_init__(self): max_stepsize=self.max_stepsize, jit=self.jit, unroll=unroll, - verbose=self.verbose + verbose=int(self.verbose)-1 ) self.run_ls = linesearch_solver.run diff --git a/jaxopt/_src/optax_wrapper.py b/jaxopt/_src/optax_wrapper.py index 98b83b78..b17f4a66 100644 --- a/jaxopt/_src/optax_wrapper.py +++ b/jaxopt/_src/optax_wrapper.py @@ -18,6 +18,7 @@ from typing import Callable from typing import NamedTuple from typing import Optional +from typing import Union from dataclasses import dataclass @@ -67,7 +68,7 @@ class OptaxSolver(base.StochasticSolver): maxiter: maximum number of solver iterations. tol: tolerance to use. - verbose: whether to print error on every iteration or not. + verbose: whether to print information on every iteration or not. implicit_diff: whether to enable implicit diff or autodiff of unrolled iterations. 
@@ -82,7 +83,7 @@ class OptaxSolver(base.StochasticSolver): pre_update: Optional[Callable] = None maxiter: int = 500 tol: float = 1e-3 - verbose: int = 0 + verbose: Union[bool, int] = False implicit_diff: bool = False implicit_diff_solve: Optional[Callable] = None has_aux: bool = False @@ -149,6 +150,13 @@ def update(self, value=jnp.asarray(value), aux=aux, internal_state=opt_state) + + if self.verbose: + self.log_info( + new_state, + error_name="Gradient Norm", + additional_info={"Objective Value": value} + ) return base.OptStep(params=params, state=new_state) def optimality_fun(self, params, *args, **kwargs): diff --git a/jaxopt/_src/osqp.py b/jaxopt/_src/osqp.py index 5c700b41..a62c7f63 100644 --- a/jaxopt/_src/osqp.py +++ b/jaxopt/_src/osqp.py @@ -381,7 +381,9 @@ class BoxOSQP(base.IterativeSolver): tol: absolute tolerance for stoping criterion (default: 1e-3). termination_check_frequency: frequency of termination check. (default: 5). One every `termination_check_frequency` the error is computed. - verbose: If verbose=1, print error at each iteration. If verbose=2, also print stepsizes and primal/dual variables. + verbose: If verbose=1 or True, print error at each iteration. + If verbose=2, also print stepsizes and primal/dual variables. + If verbose=3, also print primal and dual residuals. implicit_diff: whether to enable implicit diff or autodiff of unrolled iterations. implicit_diff_solve: the linear system solver to use. jit: whether to JIT-compile the optimization loop (default: True). @@ -417,7 +419,7 @@ class BoxOSQP(base.IterativeSolver): maxiter: int = 4000 tol: float = 1e-3 termination_check_frequency: int = 5 - verbose: int = 0 + verbose: Union[bool, int] = False implicit_diff: bool = True implicit_diff_solve: Optional[Callable] = None jit: bool = True @@ -554,7 +556,7 @@ def _check_dual_infeasability(self, error, status, delta_x, Q, c, Adx, l, u): certif_dual_infeasible = jnp.logical_and(jnp.logical_and(certif_Q <= criterion, certif_c <= criterion), certif_A) - if self.verbose >= 2: + if int(self.verbose) >= 2: jax.debug.print("certif_Q: {certif_Q} certif_c: {certif_c} certif_A: {certif_A} " "criterion: {criterion}, Adx: {Adx}, certif_l: {certif_l}, certif_u: {certif_u}", certif_Q=certif_Q, certif_c=certif_c, certif_A=certif_A, criterion=criterion, @@ -576,7 +578,7 @@ def _check_primal_infeasability(self, error, status, delta_y, ATdy, l, u): certif_lu = tree_add(tree_vdot(bounded_l, dy_minus), tree_vdot(bounded_u, dy_plus)) certif_primal_infeasible = jnp.logical_and(certif_A <= criterion, certif_lu <= criterion) - if self.verbose >= 2: + if int(self.verbose) >= 2: jax.debug.print("certif_A: {certif_A}, certif_lu: {certif_lu}, criterion: {criterion}", certif_A=certif_A, certif_lu=certif_lu, criterion=criterion) @@ -688,15 +690,15 @@ def update(self, # for active constraints (in particular equality constraints) high stepsize is better rho_bar = state.rho_bar - if self.verbose >= 2: + if int(self.verbose) >= 2: jax.debug.print("rho_bar: {rho_bar}", rho_bar=rho_bar) (x, z), y, solver_state = self._admm_step(params, Q, c, A, (l, u), rho_bar, state) - if self.verbose >= 3: + if int(self.verbose) >= 3: jax.debug.print("x: {x} z: {z} y: {y}", x=x, z=z, y=y) primal_residuals, dual_residuals = self._compute_residuals(Q, c, A, x, z, y) - if self.verbose >= 3: + if int(self.verbose) >= 3: jax.debug.print("primal_residuals: {primal_residuals}, dual_residuals: {dual_residuals}", primal_residuals=primal_residuals, dual_residuals=dual_residuals) diff --git a/jaxopt/_src/polyak_sgd.py 
b/jaxopt/_src/polyak_sgd.py index c8b19a12..2ba74630 100644 --- a/jaxopt/_src/polyak_sgd.py +++ b/jaxopt/_src/polyak_sgd.py @@ -18,6 +18,7 @@ from typing import Callable from typing import NamedTuple from typing import Optional +from typing import Union import dataclasses @@ -87,7 +88,7 @@ class PolyakSGD(base.StochasticSolver): maxiter: maximum number of solver iterations. tol: tolerance to use. - verbose: whether to print error on every iteration or not. + verbose: whether to print information on every iteration or not. implicit_diff: whether to enable implicit diff or autodiff of unrolled iterations. @@ -120,7 +121,7 @@ class PolyakSGD(base.StochasticSolver): maxiter: int = 500 tol: float = 1e-3 - verbose: int = 0 + verbose: Union[bool, int] = False implicit_diff: bool = False implicit_diff_solve: Optional[Callable] = None @@ -205,6 +206,16 @@ def update(self, aux=aux, num_fun_eval=state.num_fun_eval + 1, num_grad_eval=state.num_grad_eval + 1) + + if self.verbose: + self.log_info( + new_state, + error_name="Gradient Norm", + additional_info={ + "Objective Value": value, + "Stepsize": stepsize, + } + ) return base.OptStep(params=new_params, state=new_state) def optimality_fun(self, params, *args, **kwargs): diff --git a/jaxopt/_src/projected_gradient.py b/jaxopt/_src/projected_gradient.py index dfa92101..5093e547 100644 --- a/jaxopt/_src/projected_gradient.py +++ b/jaxopt/_src/projected_gradient.py @@ -61,7 +61,7 @@ class ProjectedGradient(base.IterativeSolver): tol: tolerance to use. acceleration: whether to use acceleration (also known as FISTA) or not. - verbose: whether to print error on every iteration or not. + verbose: whether to print information on every iteration or not. implicit_diff: whether to enable implicit diff or autodiff of unrolled iterations. @@ -85,7 +85,7 @@ class ProjectedGradient(base.IterativeSolver): acceleration: bool = True decrease_factor: float = 0.5 - verbose: int = 0 + verbose: Union[bool, int] = False implicit_diff: bool = True implicit_diff_solve: Optional[Callable] = None diff --git a/jaxopt/_src/proximal_gradient.py b/jaxopt/_src/proximal_gradient.py index 877691eb..76dba7e9 100644 --- a/jaxopt/_src/proximal_gradient.py +++ b/jaxopt/_src/proximal_gradient.py @@ -126,7 +126,7 @@ class ProximalGradient(base.IterativeSolver): tol: tolerance to use. acceleration: whether to use acceleration (also known as FISTA) or not. decrease_factor: factor by which to reduce the stepsize during line search. - verbose: whether to print error on every iteration or not. + verbose: whether to print information on every iteration or not. implicit_diff: whether to enable implicit diff or autodiff of unrolled iterations. 
@@ -153,7 +153,7 @@ class ProximalGradient(base.IterativeSolver): tol: float = 1e-3 acceleration: bool = True decrease_factor: float = 0.5 - verbose: int = 0 + verbose: Union[bool, int] = False implicit_diff: bool = True implicit_diff_solve: Optional[Callable] = None @@ -273,6 +273,15 @@ def _update_accel(self, x, state, hyperparams_prox, args, kwargs): stepsize=jnp.asarray(next_stepsize, dtype=dtype), error=jnp.asarray(next_error, dtype=dtype), aux=aux) + + if self.verbose: + self.log_info( + next_state, + error_name="Distance btw Iterates", + additional_info={ + "Stepsize": next_stepsize + } + ) return base.OptStep(params=next_x, state=next_state) def update(self, diff --git a/jaxopt/_src/zoom_linesearch.py b/jaxopt/_src/zoom_linesearch.py index 684eeef3..de732e67 100644 --- a/jaxopt/_src/zoom_linesearch.py +++ b/jaxopt/_src/zoom_linesearch.py @@ -224,8 +224,7 @@ class ZoomLineSearch(base.IterativeLineSearch): max_stepsize: maximal possible stepsize, (default: 2**maxiter) tol: tolerance of the stopping criterion. (default: 0.0) maxiter: maximum number of line search iterations. (default: 30) - verbose: whether to print error on every iteration or not. verbose=True will - automatically disable jit. (default: False) + verbose: whether to print information on every iteration or not. jit: whether to JIT-compile the optimization loop (default: "auto"). unroll: whether to unroll the optimization loop (default: "auto"). """ @@ -246,7 +245,7 @@ class ZoomLineSearch(base.IterativeLineSearch): maxiter: int = 30 max_stepsize: Optional[float] = None - verbose: bool = False + verbose: Union[bool, int] = False jit: base.AutoOrBoolean = "auto" unroll: base.AutoOrBoolean = "auto" @@ -287,14 +286,17 @@ def _curvature_error(self, slope_step, slope_init): def _make_safe_step(self, stepsize, state, args, kwargs): safe_stepsize = state.safe_stepsize + outside_domain = jnp.isinf(state.decrease_error) + final_stepsize = jnp.where((safe_stepsize > 0.) | outside_domain, safe_stepsize, stepsize) if self.verbose: - _cond_print((safe_stepsize > 0.), FLAG_CURVATURE_COND_NOT_SATSIFIED) - final_stepsize = jax.lax.cond( - safe_stepsize > 0., - lambda safe_stepsize, *_: safe_stepsize, - self.failure_diagnostic, - safe_stepsize, stepsize, state - ) + jax.lax.cond( + safe_stepsize > 0., + lambda *_: jax.debug.print(FLAG_CURVATURE_COND_NOT_SATSIFIED), + self.failure_diagnostic, + stepsize, + state + ) + step = tree_add_scalar_mul( state.params, final_stepsize, state.descent_direction ) @@ -335,8 +337,8 @@ def _search_interval(self, init_stepsize, state, args, kwargs): # Choose new point, larger than previous one or set to initial guess # for first iteration. 
larger_stepsize = self.increase_factor * prev_stepsize - new_stepsize_ = jnp.where(iter_num == 0, init_stepsize, larger_stepsize) - new_stepsize = jnp.minimum(new_stepsize_, self.max_stepsize) + new_stepsize = jnp.where(iter_num == 0, init_stepsize, larger_stepsize) + new_stepsize = jnp.minimum(new_stepsize, self.max_stepsize) max_stepsize_reached = new_stepsize >= self.max_stepsize new_value_step, new_slope_step, new_step, new_grad_step, new_aux_step = ( @@ -766,7 +768,7 @@ def update( del grad del descent_direction - best_stepsize_, new_state_ = cond( + new_stepsize, new_state = cond( state.interval_found, self._zoom_into_interval, self._search_interval, @@ -776,56 +778,55 @@ def update( fun_kwargs, jit=self.jit, ) - new_state_ = new_state_._replace( - num_fun_eval=new_state_.num_fun_eval + 1, - num_grad_eval=new_state_.num_grad_eval + 1, + new_state = new_state._replace( + num_fun_eval=new_state.num_fun_eval + 1, + num_grad_eval=new_state.num_grad_eval + 1, ) anticipated_num_func_grad_calls = jnp.array( - new_state_.failed + new_state.failed ).astype(base.NUM_EVAL_DTYPE) - best_stepsize, new_state = cond( - new_state_.failed, + new_stepsize, new_state = cond( + new_state.failed, self._make_safe_step, self._keep_step, - best_stepsize_, - new_state_, + new_stepsize, + new_state, fun_args, fun_kwargs, jit=self.jit, ) new_state = new_state._replace( - num_fun_eval=new_state_.num_fun_eval + anticipated_num_func_grad_calls, - num_grad_eval=new_state_.num_grad_eval + anticipated_num_func_grad_calls, + num_fun_eval=new_state.num_fun_eval + anticipated_num_func_grad_calls, + num_grad_eval=new_state.num_grad_eval + anticipated_num_func_grad_calls, ) - return base.LineSearchStep(stepsize=best_stepsize, state=new_state) + if self.verbose: + self._log_info(new_state, new_stepsize) + + return base.LineSearchStep(stepsize=new_stepsize, state=new_state) def _cond_fun(self, inputs): # Stop the linesearch according to done or failed rather than the error as one may # reach the maximal stepsize and no decrease of the curvature error may be # possible or the searched interval has been reduced too much. - stepsize, state = inputs[0] - if self.verbose: - self._log_info(stepsize, state) + _, state = inputs[0] return ~(state.done | state.failed) - def _log_info(self, stepsize, state): - jax.debug.print( - "INFO: jaxopt.ZoomLineSearch: " + \ - "Iter: {iter}, " + \ - "Stepsize: {stepsize}, " + \ - "Decrease error: {decrease_error}, " + \ - "Curvature error: {curvature_error}", - iter=state.iter_num, - stepsize=stepsize, - decrease_error=state.decrease_error, - curvature_error=state.curvature_error + def _log_info(self, state, stepsize): + self.log_info( + state, + error_name="Minimum Decrease & Curvature Errors", + additional_info={ + "Stepsize": stepsize, + "Decrease Error": state.decrease_error, + "Curvature Error": state.curvature_error + } ) - def failure_diagnostic(self, safe_stepsize, stepsize, state): + def failure_diagnostic(self, stepsize, state): jax.debug.print(FLAG_NO_STEPSIZE_FOUND) - self._log_info(stepsize, state) + self._log_info(state, stepsize) slope_init = state.slope_init is_descent_dir = slope_init < 0. @@ -876,8 +877,6 @@ def failure_diagnostic(self, safe_stepsize, stepsize, state): "Making an unsafe step, not decreasing enough the objective. " + \ "Convergence of the solver is compromised as it does not reduce values." 
) - final_stepsize = jnp.where(outside_domain, safe_stepsize, stepsize) - return final_stepsize def __post_init__(self): self._fun_with_aux, _, self._value_and_grad_fun_with_aux = ( diff --git a/tests/common_test.py b/tests/common_test.py index d640532b..45d71f01 100644 --- a/tests/common_test.py +++ b/tests/common_test.py @@ -388,25 +388,25 @@ def fixed_point_fun(params): solvers = ( # Unconstrained - jaxopt.GradientDescent(fun=fun, jit=True, verbose=1, maxiter=4), - jaxopt.PolyakSGD(fun=fun, jit=True, verbose=1, maxiter=4), + jaxopt.GradientDescent(fun=fun, jit=True, verbose=True, maxiter=4), + jaxopt.PolyakSGD(fun=fun, jit=True, verbose=True, maxiter=4), jaxopt.Broyden(fun=fixed_point_fun, jit=True, verbose=True, maxiter=4), jaxopt.AndersonAcceleration(fixed_point_fun=fixed_point_fun, jit=True, verbose=True, maxiter=4), - jaxopt.ArmijoSGD(fun=fun, jit=True, verbose=1, maxiter=4), - jaxopt.BFGS(fun, linesearch="zoom", jit=True, verbose=1, maxiter=4), - jaxopt.BFGS(fun, linesearch="backtracking", jit=True, verbose=1, maxiter=4), - jaxopt.BFGS(fun, linesearch="hager-zhang", jit=True, verbose=1, maxiter=4), - jaxopt.LBFGS(fun=fun, jit=True, verbose=1, maxiter=4), - jaxopt.ArmijoSGD(fun=fun, jit=True, verbose=1, maxiter=4), - jaxopt.NonlinearCG(fun, jit=True, verbose=1, maxiter=4), + jaxopt.ArmijoSGD(fun=fun, jit=True, verbose=True, maxiter=4), + jaxopt.BFGS(fun, linesearch="zoom", jit=True, verbose=True, maxiter=4), + jaxopt.BFGS(fun, linesearch="backtracking", jit=True, verbose=True, maxiter=4), + jaxopt.BFGS(fun, linesearch="hager-zhang", jit=True, verbose=True, maxiter=4), + jaxopt.LBFGS(fun=fun, jit=True, verbose=True, maxiter=4), + jaxopt.ArmijoSGD(fun=fun, jit=True, verbose=True, maxiter=4), + jaxopt.NonlinearCG(fun, jit=True, verbose=True, maxiter=4), # Unconstrained, nonlinear least-squares jaxopt.GaussNewton(residual_fun=fun, jit=True, verbose=True, maxiter=4), jaxopt.LevenbergMarquardt(residual_fun=fun, jit=True, verbose=True, maxiter=4), # Constrained jaxopt.ProjectedGradient(fun=fun, - projection=jaxopt.projection.projection_non_negative, jit=True, verbose=1, maxiter=4), + projection=jaxopt.projection.projection_non_negative, jit=True, verbose=True, maxiter=4), # Optax wrapper - jaxopt.OptaxSolver(opt=optax.adam(1e-1), fun=fun, jit=True, verbose=1, maxiter=4), + jaxopt.OptaxSolver(opt=optax.adam(1e-1), fun=fun, jit=True, verbose=True, maxiter=4), ) @partial(jax.jit, static_argnums=(1,)) @@ -429,10 +429,10 @@ def run_solver_prox(p0, solver): return solver.run(p0, hyperparams_prox=1.0, data=data) for solver in (jaxopt.ProximalGradient(fun=fun, prox=prox.prox_lasso, - jit=True, verbose=1, maxiter=4), + jit=True, verbose=True, maxiter=4), jaxopt.BlockCoordinateDescent(fun=fun, block_prox=prox.prox_lasso, - jit=True, verbose=1, maxiter=4) + jit=True, verbose=True, maxiter=4) ): with redirect_stdout(io.StringIO()): run_solver_prox(params0, solver) @@ -465,7 +465,7 @@ def run_mirror_descent(b0): stepsize=1e-3, maxiter=4, jit=True, - verbose=1) + verbose=True) _, state = md.run(b0, None, lam, data) return state @@ -486,7 +486,7 @@ def run_mirror_descent(b0): @jax.jit def run_box_osqp(params_obj, params_ineq): - osqp = BoxOSQP(matvec_Q=matvec_Q, matvec_A=matvec_A, tol=tol, jit=True, verbose=1, maxiter=4) + osqp = BoxOSQP(matvec_Q=matvec_Q, matvec_A=matvec_A, tol=tol, jit=True, verbose=True, maxiter=4) return osqp.run(None, (None, params_obj), None, (params_ineq, params_ineq)) with redirect_stdout(io.StringIO()): diff --git a/tests/zoom_linesearch_test.py b/tests/zoom_linesearch_test.py 
b/tests/zoom_linesearch_test.py
index 7702e7b2..207fcd0a 100644
--- a/tests/zoom_linesearch_test.py
+++ b/tests/zoom_linesearch_test.py
@@ -232,7 +232,7 @@ def fun(x):
     # Test that the line search fails for p not a descent direction
     # For high maxiter, still finds a decrease error because of
     # the approximate Wolfe condition so we reduced maxiter
-    ls = ZoomLineSearch(fun, c2=0.5, maxiter=18)
+    ls = ZoomLineSearch(fun, c2=0.5, maxiter=18, verbose=True)
     stdout = io.StringIO()
     with redirect_stdout(stdout):
       s, state = ls.run(init_stepsize=1.0, params=x, descent_direction=p)
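
---

Usage sketch (not part of the patch) illustrating the behaviour this change enables. It relies only on the public jaxopt API touched above; the quadratic objective and the printed values are illustrative.

```python
# Sketch of the unified `verbose` option introduced by this patch.
# Assumes jaxopt with this patch applied; the objective is illustrative.
import jax.numpy as jnp
import jaxopt


def fun(w):
  # Simple smooth objective so the solver converges in a few iterations.
  return jnp.sum((w - 1.0) ** 2)


# verbose=True (or 1) logs one line per solver iteration via `log_info`,
# roughly of the form:
#   INFO: jaxopt.LBFGS: Iter: 2 Gradient Norm (stop. crit.): ... Objective Value:... Stepsize:...
# verbose=2 additionally enables logging inside the linesearch, since solvers
# such as BFGS/LBFGS/LBFGSB/NonlinearCG pass `int(verbose) - 1` to it.
solver = jaxopt.LBFGS(fun=fun, maxiter=10, verbose=True)
params, state = solver.run(jnp.zeros(3))
```

Because the logging goes through `jax.debug.print` in `base.log_info`, it remains compatible with `jit=True`, which is what the updated `tests/common_test.py` exercises by running every solver with `jit=True, verbose=True` under `jax.jit`.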