From 8da33502aa47341e48230c2a34cd6bb55df315ed Mon Sep 17 00:00:00 2001 From: Yue Sheng Date: Tue, 19 Mar 2024 14:38:05 -0700 Subject: [PATCH] No public description PiperOrigin-RevId: 617298469 --- tests/common_test.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/common_test.py b/tests/common_test.py index fe6e4b72..02698425 100644 --- a/tests/common_test.py +++ b/tests/common_test.py @@ -417,7 +417,7 @@ def run_solver(p0, solver): for solver in solvers: stdout = io.StringIO() with redirect_stdout(stdout): - run_solver(jnp.arange(2.), solver) + jax.block_until_ready(run_solver(jnp.arange(2.), solver)) printed = len(stdout.getvalue()) > 0 if verbose: self.assertTrue(printed) @@ -443,7 +443,7 @@ def run_solver_prox(p0, solver): ): stdout = io.StringIO() with redirect_stdout(stdout): - run_solver_prox(params0, solver) + jax.block_until_ready(run_solver_prox(params0, solver)) printed = len(stdout.getvalue()) > 0 if verbose: self.assertTrue(printed) @@ -484,7 +484,7 @@ def run_mirror_descent(b0): stdout = io.StringIO() with redirect_stdout(stdout): - run_mirror_descent(beta_init) + jax.block_until_ready(run_mirror_descent(beta_init)) printed = len(stdout.getvalue()) > 0 if verbose: self.assertTrue(printed) @@ -510,7 +510,7 @@ def run_box_osqp(params_obj, params_ineq): stdout = io.StringIO() with redirect_stdout(stdout): - run_box_osqp(q, b) + jax.block_until_ready(run_box_osqp(q, b)) printed = len(stdout.getvalue()) > 0 if verbose: self.assertTrue(printed)