Skip to content

Commit

Permalink
added nonregression test
Browse files Browse the repository at this point in the history
  • Loading branch information
zaccharieramzi committed Oct 5, 2023
1 parent 08fd2f8 commit b5c7dd3
Showing 1 changed file with 15 additions and 3 deletions.
18 changes: 15 additions & 3 deletions tests/scipy_wrappers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,18 @@ def test_maxiter_api(self):
with self.assertRaises(ValueError):
ScipyMinimize(fun=self.logreg_fun, maxiter=500, options={'maxiter': 100})

@parameterized.product(method=[
'Nelder-Mead',
'Powell',
'TNC',
])
def test_no_njev(self, method):
# test methods that do not return njev
# see https://github.com/google/jaxopt/pull/542
solver = ScipyMinimize(fun=self.logreg_fun, method=method)
pytree_init = jnp.zeros([self.n_features, self.n_classes])
solver.run(pytree_init, l2reg=self.default_l2reg, data=self.data)

def test_callback(self):
# test the callback API
trace = []
Expand Down Expand Up @@ -388,7 +400,7 @@ def wrapper(b):
jac_idf = [list(blk_row) for blk_row in jac_idf]
jac_idf = jnp.block(jac_idf)
self.assertArraysAllClose(jac_theo, jac_idf, atol=1e-3)

def test_broyden(self):
# non regression test based on issue https://github.com/google/jaxopt/issues/290
def simple_root(x):
Expand Down Expand Up @@ -618,7 +630,7 @@ def setUp(self):

def test_inverse_hessian_bfgs(self):
# perform the bfgs algorithm search
res_scipy = osp.optimize.minimize(self.objective, self.x0,
res_scipy = osp.optimize.minimize(self.objective, self.x0,
method='BFGS', jac=self.derivative)
# run the same algorithm via jaxopt
res_jaxopt = ScipyMinimize(method='BFGS', fun=self.objective).run(self.x0)
Expand All @@ -627,7 +639,7 @@ def test_inverse_hessian_bfgs(self):

def test_inverse_hessian_lbfgs(self):
# perform the l-bfgs-b algorithm search
res_scipy = osp.optimize.minimize(self.objective, self.x0,
res_scipy = osp.optimize.minimize(self.objective, self.x0,
method='L-BFGS-B', jac=self.derivative)
# run the same algorithm via jaxopt
res_jaxopt = ScipyMinimize(method='L-BFGS-B', fun=self.objective).run(self.x0)
Expand Down

0 comments on commit b5c7dd3

Please sign in to comment.