Skip to content

Commit

Permalink
fix_print
Browse files Browse the repository at this point in the history
  • Loading branch information
vroulet committed Jan 24, 2024
1 parent 7b4dd31 commit 716011b
Show file tree
Hide file tree
Showing 8 changed files with 58 additions and 30 deletions.
2 changes: 1 addition & 1 deletion jaxopt/_src/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ def log_info(self, state, error_name='Error', additional_info={}):
jax.debug.print(
"INFO: jaxopt." + name + ": " + \
"Iter: {} " + \
error_name + " (stop. crit.): {} " + \
error_name + " (stopping criterion): {} " + \
other_info_kw,
state.iter_num,
state.error,
Expand Down
2 changes: 1 addition & 1 deletion jaxopt/_src/bfgs.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ def __post_init__(self):
max_stepsize=self.max_stepsize,
jit=self.jit,
unroll=unroll,
verbose=int(self.verbose)-1
verbose=max(int(self.verbose)-1, 0)
)
self.run_ls = self.linesearch_solver.run

Expand Down
2 changes: 1 addition & 1 deletion jaxopt/_src/broyden.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@ def ls_fun_with_aux(params, *args, **kwargs):
jit=self.jit,
unroll=self.unroll,
has_aux=True,
verbose=int(self.verbose)-1,
verbose=max(int(self.verbose)-1, 0),
tol=1e-2)
init_stepsize = jnp.where(state.stepsize <= self.min_stepsize,
# If stepsize became too small, we restart it.
Expand Down
2 changes: 1 addition & 1 deletion jaxopt/_src/lbfgs.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,7 +451,7 @@ def __post_init__(self):
max_stepsize=self.max_stepsize,
jit=self.jit,
unroll=unroll,
verbose=self.verbose,
verbose=max(int(self.verbose)-1, 0),
)
self.run_ls = self.linesearch_solver.run

Expand Down
2 changes: 1 addition & 1 deletion jaxopt/_src/lbfgsb.py
Original file line number Diff line number Diff line change
Expand Up @@ -618,7 +618,7 @@ def __post_init__(self):
max_stepsize=self.max_stepsize,
jit=self.jit,
unroll=unroll,
verbose=int(self.verbose)-1,
verbose=max(int(self.verbose)-1, 0),
)

self.run_ls = linesearch_solver.run
Expand Down
2 changes: 1 addition & 1 deletion jaxopt/_src/nonlinear_cg.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,7 @@ def __post_init__(self):
max_stepsize=self.max_stepsize,
jit=self.jit,
unroll=unroll,
verbose=int(self.verbose)-1
verbose=max(int(self.verbose)-1, 0)
)

self.run_ls = linesearch_solver.run
Expand Down
3 changes: 3 additions & 0 deletions jaxopt/_src/osqp.py
Original file line number Diff line number Diff line change
Expand Up @@ -737,6 +737,9 @@ def update(self,
dual_residuals=dual_residuals,
rho_bar=rho_bar,
solver_state=solver_state)

if self.verbose:
self.log_info(state)
return base.OptStep(params=sol, state=state)

def run(self,
Expand Down
73 changes: 49 additions & 24 deletions tests/common_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,8 @@ def fun(w, X, y):
msg = "weak_type inconsistency for attribute '%s' in solver '%s'"
self.assertEqual(weak_type0, weak_type1, msg=msg % (field, solver_name))

def test_jit_with_verbose(self):
@parameterized.product(verbose=[True, False])
def test_jit_with_or_without_verbose(self, verbose):

fun = lambda p: p @ p

Expand All @@ -388,34 +389,40 @@ def fixed_point_fun(params):

solvers = (
# Unconstrained
jaxopt.GradientDescent(fun=fun, jit=True, verbose=True, maxiter=4),
jaxopt.PolyakSGD(fun=fun, jit=True, verbose=True, maxiter=4),
jaxopt.Broyden(fun=fixed_point_fun, jit=True, verbose=True, maxiter=4),
jaxopt.AndersonAcceleration(fixed_point_fun=fixed_point_fun, jit=True, verbose=True, maxiter=4),
jaxopt.ArmijoSGD(fun=fun, jit=True, verbose=True, maxiter=4),
jaxopt.BFGS(fun, linesearch="zoom", jit=True, verbose=True, maxiter=4),
jaxopt.BFGS(fun, linesearch="backtracking", jit=True, verbose=True, maxiter=4),
jaxopt.BFGS(fun, linesearch="hager-zhang", jit=True, verbose=True, maxiter=4),
jaxopt.LBFGS(fun=fun, jit=True, verbose=True, maxiter=4),
jaxopt.ArmijoSGD(fun=fun, jit=True, verbose=True, maxiter=4),
jaxopt.NonlinearCG(fun, jit=True, verbose=True, maxiter=4),
jaxopt.GradientDescent(fun=fun, jit=True, verbose=verbose, maxiter=4),
jaxopt.PolyakSGD(fun=fun, jit=True, verbose=verbose, maxiter=4),
jaxopt.Broyden(fun=fixed_point_fun, jit=True, verbose=verbose, maxiter=4),
jaxopt.AndersonAcceleration(fixed_point_fun=fixed_point_fun, jit=True, verbose=verbose, maxiter=4),
jaxopt.ArmijoSGD(fun=fun, jit=True, verbose=verbose, maxiter=4),
jaxopt.BFGS(fun, linesearch="zoom", jit=True, verbose=verbose, maxiter=4),
jaxopt.BFGS(fun, linesearch="backtracking", jit=True, verbose=verbose, maxiter=4),
jaxopt.BFGS(fun, linesearch="hager-zhang", jit=True, verbose=verbose, maxiter=4),
jaxopt.LBFGS(fun=fun, jit=True, verbose=verbose, maxiter=4),
jaxopt.ArmijoSGD(fun=fun, jit=True, verbose=verbose, maxiter=4),
jaxopt.NonlinearCG(fun, jit=True, verbose=verbose, maxiter=4),
# Unconstrained, nonlinear least-squares
jaxopt.GaussNewton(residual_fun=fun, jit=True, verbose=True, maxiter=4),
jaxopt.LevenbergMarquardt(residual_fun=fun, jit=True, verbose=True, maxiter=4),
jaxopt.GaussNewton(residual_fun=fun, jit=True, verbose=verbose, maxiter=4),
jaxopt.LevenbergMarquardt(residual_fun=fun, jit=True, verbose=verbose, maxiter=4),
# Constrained
jaxopt.ProjectedGradient(fun=fun,
projection=jaxopt.projection.projection_non_negative, jit=True, verbose=True, maxiter=4),
projection=jaxopt.projection.projection_non_negative, jit=True, verbose=verbose, maxiter=4),
# Optax wrapper
jaxopt.OptaxSolver(opt=optax.adam(1e-1), fun=fun, jit=True, verbose=True, maxiter=4),
jaxopt.OptaxSolver(opt=optax.adam(1e-1), fun=fun, jit=True, verbose=verbose, maxiter=4),
)

@partial(jax.jit, static_argnums=(1,))
def run_solver(p0, solver):
return solver.run(p0)

for solver in solvers:
with redirect_stdout(io.StringIO()):
stdout = io.StringIO()
with redirect_stdout(stdout):
run_solver(jnp.arange(2.), solver)
printed = len(stdout.getvalue()) > 0
if verbose:
self.assertTrue(printed)
else:
self.assertFalse(printed)

# Proximal gradient solvers
fun = objective.least_squares
Expand All @@ -429,13 +436,19 @@ def run_solver_prox(p0, solver):
return solver.run(p0, hyperparams_prox=1.0, data=data)

for solver in (jaxopt.ProximalGradient(fun=fun, prox=prox.prox_lasso,
jit=True, verbose=True, maxiter=4),
jit=True, verbose=verbose, maxiter=4),
jaxopt.BlockCoordinateDescent(fun=fun,
block_prox=prox.prox_lasso,
jit=True, verbose=True, maxiter=4)
jit=True, verbose=verbose, maxiter=4)
):
with redirect_stdout(io.StringIO()):
stdout = io.StringIO()
with redirect_stdout(stdout):
run_solver_prox(params0, solver)
printed = len(stdout.getvalue()) > 0
if verbose:
self.assertTrue(printed)
else:
self.assertFalse(printed)

# Mirror Descent
Y = preprocessing.LabelBinarizer().fit_transform(y)
Expand Down Expand Up @@ -465,12 +478,18 @@ def run_mirror_descent(b0):
stepsize=1e-3,
maxiter=4,
jit=True,
verbose=True)
verbose=verbose)
_, state = md.run(b0, None, lam, data)
return state

with redirect_stdout(io.StringIO()):
stdout = io.StringIO()
with redirect_stdout(stdout):
run_mirror_descent(beta_init)
printed = len(stdout.getvalue()) > 0
if verbose:
self.assertTrue(printed)
else:
self.assertFalse(printed)

# Quadratic programming - BoxOSQP
x = jnp.array([1.0, 2.0])
Expand All @@ -486,11 +505,17 @@ def run_mirror_descent(b0):

@jax.jit
def run_box_osqp(params_obj, params_ineq):
osqp = BoxOSQP(matvec_Q=matvec_Q, matvec_A=matvec_A, tol=tol, jit=True, verbose=True, maxiter=4)
osqp = BoxOSQP(matvec_Q=matvec_Q, matvec_A=matvec_A, tol=tol, jit=True, verbose=verbose, maxiter=4)
return osqp.run(None, (None, params_obj), None, (params_ineq, params_ineq))

with redirect_stdout(io.StringIO()):
stdout = io.StringIO()
with redirect_stdout(stdout):
run_box_osqp(q, b)
printed = len(stdout.getvalue()) > 0
if verbose:
self.assertTrue(printed)
else:
self.assertFalse(printed)


if __name__ == '__main__':
Expand Down

0 comments on commit 716011b

Please sign in to comment.