Skip to content

Commit

Permalink
Explicitly wait until computations finish in the verbose test. Should…
Browse files Browse the repository at this point in the history
… be NFC since currently `jax.jit` is sync on CPU.

Context: we plan to make `jax.jit` async on CPU backend for expensive computations, which will fail this test without this change.
PiperOrigin-RevId: 616950120
  • Loading branch information
yueshengys authored and JAXopt authors committed Mar 18, 2024
1 parent 501cc20 commit 16fa573
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions tests/common_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,7 +417,7 @@ def run_solver(p0, solver):
for solver in solvers:
stdout = io.StringIO()
with redirect_stdout(stdout):
run_solver(jnp.arange(2.), solver)
jax.block_until_ready(run_solver(jnp.arange(2.), solver))
printed = len(stdout.getvalue()) > 0
if verbose:
self.assertTrue(printed)
Expand All @@ -443,7 +443,7 @@ def run_solver_prox(p0, solver):
):
stdout = io.StringIO()
with redirect_stdout(stdout):
run_solver_prox(params0, solver)
jax.block_until_ready(run_solver_prox(params0, solver))
printed = len(stdout.getvalue()) > 0
if verbose:
self.assertTrue(printed)
Expand Down Expand Up @@ -484,7 +484,7 @@ def run_mirror_descent(b0):

stdout = io.StringIO()
with redirect_stdout(stdout):
run_mirror_descent(beta_init)
jax.block_until_ready(run_mirror_descent(beta_init))
printed = len(stdout.getvalue()) > 0
if verbose:
self.assertTrue(printed)
Expand All @@ -510,7 +510,7 @@ def run_box_osqp(params_obj, params_ineq):

stdout = io.StringIO()
with redirect_stdout(stdout):
run_box_osqp(q, b)
jax.block_until_ready(run_box_osqp(q, b))
printed = len(stdout.getvalue()) > 0
if verbose:
self.assertTrue(printed)
Expand Down

0 comments on commit 16fa573

Please sign in to comment.