From c818497c55f15fde0d2acf47b507f61cac7ed4ae Mon Sep 17 00:00:00 2001 From: Vincent Roulet Date: Wed, 24 Jan 2024 09:14:36 +0100 Subject: [PATCH] fix_print --- jaxopt/_src/base.py | 2 +- jaxopt/_src/bfgs.py | 2 +- jaxopt/_src/broyden.py | 2 +- jaxopt/_src/lbfgs.py | 2 +- jaxopt/_src/lbfgsb.py | 2 +- jaxopt/_src/nonlinear_cg.py | 2 +- jaxopt/_src/osqp.py | 3 ++ tests/common_test.py | 73 +++++++++++++++++++++++++------------ 8 files changed, 58 insertions(+), 30 deletions(-) diff --git a/jaxopt/_src/base.py b/jaxopt/_src/base.py index e3b7a507..5d0a1aa3 100644 --- a/jaxopt/_src/base.py +++ b/jaxopt/_src/base.py @@ -273,7 +273,7 @@ def log_info(self, state, error_name='Error', additional_info={}): jax.debug.print( "INFO: jaxopt." + name + ": " + \ "Iter: {} " + \ - error_name + " (stop. crit.): {} " + \ + error_name + " (stopping criterion): {} " + \ other_info_kw, state.iter_num, state.error, diff --git a/jaxopt/_src/bfgs.py b/jaxopt/_src/bfgs.py index d7b6939b..29780d03 100644 --- a/jaxopt/_src/bfgs.py +++ b/jaxopt/_src/bfgs.py @@ -301,7 +301,7 @@ def __post_init__(self): max_stepsize=self.max_stepsize, jit=self.jit, unroll=unroll, - verbose=int(self.verbose)-1 + verbose=max(int(self.verbose)-1, 0) ) self.run_ls = self.linesearch_solver.run diff --git a/jaxopt/_src/broyden.py b/jaxopt/_src/broyden.py index df968319..96708139 100644 --- a/jaxopt/_src/broyden.py +++ b/jaxopt/_src/broyden.py @@ -328,7 +328,7 @@ def ls_fun_with_aux(params, *args, **kwargs): jit=self.jit, unroll=self.unroll, has_aux=True, - verbose=int(self.verbose)-1, + verbose=max(int(self.verbose)-1, 0), tol=1e-2) init_stepsize = jnp.where(state.stepsize <= self.min_stepsize, # If stepsize became too small, we restart it. diff --git a/jaxopt/_src/lbfgs.py b/jaxopt/_src/lbfgs.py index 1be37e1b..d24324b9 100644 --- a/jaxopt/_src/lbfgs.py +++ b/jaxopt/_src/lbfgs.py @@ -451,7 +451,7 @@ def __post_init__(self): max_stepsize=self.max_stepsize, jit=self.jit, unroll=unroll, - verbose=self.verbose, + verbose=max(int(self.verbose)-1, 0), ) self.run_ls = self.linesearch_solver.run diff --git a/jaxopt/_src/lbfgsb.py b/jaxopt/_src/lbfgsb.py index 030e7082..44f5cf01 100644 --- a/jaxopt/_src/lbfgsb.py +++ b/jaxopt/_src/lbfgsb.py @@ -618,7 +618,7 @@ def __post_init__(self): max_stepsize=self.max_stepsize, jit=self.jit, unroll=unroll, - verbose=int(self.verbose)-1, + verbose=max(int(self.verbose)-1, 0), ) self.run_ls = linesearch_solver.run diff --git a/jaxopt/_src/nonlinear_cg.py b/jaxopt/_src/nonlinear_cg.py index 398f3db2..eff0a797 100644 --- a/jaxopt/_src/nonlinear_cg.py +++ b/jaxopt/_src/nonlinear_cg.py @@ -314,7 +314,7 @@ def __post_init__(self): max_stepsize=self.max_stepsize, jit=self.jit, unroll=unroll, - verbose=int(self.verbose)-1 + verbose=max(int(self.verbose)-1, 0) ) self.run_ls = linesearch_solver.run diff --git a/jaxopt/_src/osqp.py b/jaxopt/_src/osqp.py index a62c7f63..7eae7f86 100644 --- a/jaxopt/_src/osqp.py +++ b/jaxopt/_src/osqp.py @@ -737,6 +737,9 @@ def update(self, dual_residuals=dual_residuals, rho_bar=rho_bar, solver_state=solver_state) + + if self.verbose: + self.log_info(state) return base.OptStep(params=sol, state=state) def run(self, diff --git a/tests/common_test.py b/tests/common_test.py index 45d71f01..fe6e4b72 100644 --- a/tests/common_test.py +++ b/tests/common_test.py @@ -379,7 +379,8 @@ def fun(w, X, y): msg = "weak_type inconsistency for attribute '%s' in solver '%s'" self.assertEqual(weak_type0, weak_type1, msg=msg % (field, solver_name)) - def test_jit_with_verbose(self): + @parameterized.product(verbose=[True, False]) + def test_jit_with_or_without_verbose(self, verbose): fun = lambda p: p @ p @@ -388,25 +389,25 @@ def fixed_point_fun(params): solvers = ( # Unconstrained - jaxopt.GradientDescent(fun=fun, jit=True, verbose=True, maxiter=4), - jaxopt.PolyakSGD(fun=fun, jit=True, verbose=True, maxiter=4), - jaxopt.Broyden(fun=fixed_point_fun, jit=True, verbose=True, maxiter=4), - jaxopt.AndersonAcceleration(fixed_point_fun=fixed_point_fun, jit=True, verbose=True, maxiter=4), - jaxopt.ArmijoSGD(fun=fun, jit=True, verbose=True, maxiter=4), - jaxopt.BFGS(fun, linesearch="zoom", jit=True, verbose=True, maxiter=4), - jaxopt.BFGS(fun, linesearch="backtracking", jit=True, verbose=True, maxiter=4), - jaxopt.BFGS(fun, linesearch="hager-zhang", jit=True, verbose=True, maxiter=4), - jaxopt.LBFGS(fun=fun, jit=True, verbose=True, maxiter=4), - jaxopt.ArmijoSGD(fun=fun, jit=True, verbose=True, maxiter=4), - jaxopt.NonlinearCG(fun, jit=True, verbose=True, maxiter=4), + jaxopt.GradientDescent(fun=fun, jit=True, verbose=verbose, maxiter=4), + jaxopt.PolyakSGD(fun=fun, jit=True, verbose=verbose, maxiter=4), + jaxopt.Broyden(fun=fixed_point_fun, jit=True, verbose=verbose, maxiter=4), + jaxopt.AndersonAcceleration(fixed_point_fun=fixed_point_fun, jit=True, verbose=verbose, maxiter=4), + jaxopt.ArmijoSGD(fun=fun, jit=True, verbose=verbose, maxiter=4), + jaxopt.BFGS(fun, linesearch="zoom", jit=True, verbose=verbose, maxiter=4), + jaxopt.BFGS(fun, linesearch="backtracking", jit=True, verbose=verbose, maxiter=4), + jaxopt.BFGS(fun, linesearch="hager-zhang", jit=True, verbose=verbose, maxiter=4), + jaxopt.LBFGS(fun=fun, jit=True, verbose=verbose, maxiter=4), + jaxopt.ArmijoSGD(fun=fun, jit=True, verbose=verbose, maxiter=4), + jaxopt.NonlinearCG(fun, jit=True, verbose=verbose, maxiter=4), # Unconstrained, nonlinear least-squares - jaxopt.GaussNewton(residual_fun=fun, jit=True, verbose=True, maxiter=4), - jaxopt.LevenbergMarquardt(residual_fun=fun, jit=True, verbose=True, maxiter=4), + jaxopt.GaussNewton(residual_fun=fun, jit=True, verbose=verbose, maxiter=4), + jaxopt.LevenbergMarquardt(residual_fun=fun, jit=True, verbose=verbose, maxiter=4), # Constrained jaxopt.ProjectedGradient(fun=fun, - projection=jaxopt.projection.projection_non_negative, jit=True, verbose=True, maxiter=4), + projection=jaxopt.projection.projection_non_negative, jit=True, verbose=verbose, maxiter=4), # Optax wrapper - jaxopt.OptaxSolver(opt=optax.adam(1e-1), fun=fun, jit=True, verbose=True, maxiter=4), + jaxopt.OptaxSolver(opt=optax.adam(1e-1), fun=fun, jit=True, verbose=verbose, maxiter=4), ) @partial(jax.jit, static_argnums=(1,)) @@ -414,8 +415,14 @@ def run_solver(p0, solver): return solver.run(p0) for solver in solvers: - with redirect_stdout(io.StringIO()): + stdout = io.StringIO() + with redirect_stdout(stdout): run_solver(jnp.arange(2.), solver) + printed = len(stdout.getvalue()) > 0 + if verbose: + self.assertTrue(printed) + else: + self.assertFalse(printed) # Proximal gradient solvers fun = objective.least_squares @@ -429,13 +436,19 @@ def run_solver_prox(p0, solver): return solver.run(p0, hyperparams_prox=1.0, data=data) for solver in (jaxopt.ProximalGradient(fun=fun, prox=prox.prox_lasso, - jit=True, verbose=True, maxiter=4), + jit=True, verbose=verbose, maxiter=4), jaxopt.BlockCoordinateDescent(fun=fun, block_prox=prox.prox_lasso, - jit=True, verbose=True, maxiter=4) + jit=True, verbose=verbose, maxiter=4) ): - with redirect_stdout(io.StringIO()): + stdout = io.StringIO() + with redirect_stdout(stdout): run_solver_prox(params0, solver) + printed = len(stdout.getvalue()) > 0 + if verbose: + self.assertTrue(printed) + else: + self.assertFalse(printed) # Mirror Descent Y = preprocessing.LabelBinarizer().fit_transform(y) @@ -465,12 +478,18 @@ def run_mirror_descent(b0): stepsize=1e-3, maxiter=4, jit=True, - verbose=True) + verbose=verbose) _, state = md.run(b0, None, lam, data) return state - with redirect_stdout(io.StringIO()): + stdout = io.StringIO() + with redirect_stdout(stdout): run_mirror_descent(beta_init) + printed = len(stdout.getvalue()) > 0 + if verbose: + self.assertTrue(printed) + else: + self.assertFalse(printed) # Quadratic programming - BoxOSQP x = jnp.array([1.0, 2.0]) @@ -486,11 +505,17 @@ def run_mirror_descent(b0): @jax.jit def run_box_osqp(params_obj, params_ineq): - osqp = BoxOSQP(matvec_Q=matvec_Q, matvec_A=matvec_A, tol=tol, jit=True, verbose=True, maxiter=4) + osqp = BoxOSQP(matvec_Q=matvec_Q, matvec_A=matvec_A, tol=tol, jit=True, verbose=verbose, maxiter=4) return osqp.run(None, (None, params_obj), None, (params_ineq, params_ineq)) - with redirect_stdout(io.StringIO()): + stdout = io.StringIO() + with redirect_stdout(stdout): run_box_osqp(q, b) + printed = len(stdout.getvalue()) > 0 + if verbose: + self.assertTrue(printed) + else: + self.assertFalse(printed) if __name__ == '__main__':