From 7afa914261b35fd21809b6b5ba69a904082bd4ba Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Mon, 23 Oct 2023 11:27:13 -0700 Subject: [PATCH] No public description PiperOrigin-RevId: 575878172 --- docs/notebooks/implicit_diff/maml.ipynb | 2 +- docs/notebooks/implicit_diff/maml.md | 2 +- jaxopt/_src/scipy_wrappers.py | 2 +- tests/anderson_wrapper_test.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/notebooks/implicit_diff/maml.ipynb b/docs/notebooks/implicit_diff/maml.ipynb index 8edf8b19..fdd9d9ac 100644 --- a/docs/notebooks/implicit_diff/maml.ipynb +++ b/docs/notebooks/implicit_diff/maml.ipynb @@ -73,7 +73,7 @@ "except (KeyError, RuntimeError):\n", " print(\"TPU not found, continuing without it.\")\n", "\n", - "from jax.config import config\n", + "from jax import config\n", "config.update(\"jax_enable_x64\", True)\n", "\n", "import jax\n", diff --git a/docs/notebooks/implicit_diff/maml.md b/docs/notebooks/implicit_diff/maml.md index 37eb9b20..98554d62 100644 --- a/docs/notebooks/implicit_diff/maml.md +++ b/docs/notebooks/implicit_diff/maml.md @@ -51,7 +51,7 @@ try: except (KeyError, RuntimeError): print("TPU not found, continuing without it.") -from jax.config import config +from jax import config config.update("jax_enable_x64", True) import jax diff --git a/jaxopt/_src/scipy_wrappers.py b/jaxopt/_src/scipy_wrappers.py index bba65419..55c3320d 100644 --- a/jaxopt/_src/scipy_wrappers.py +++ b/jaxopt/_src/scipy_wrappers.py @@ -34,7 +34,7 @@ from typing import Union import jax -from jax.config import config +from jax import config import jax.numpy as jnp import jax.tree_util as tree_util from jax.tree_util import register_pytree_node_class diff --git a/tests/anderson_wrapper_test.py b/tests/anderson_wrapper_test.py index 3df8c9de..f9a58fa6 100644 --- a/tests/anderson_wrapper_test.py +++ b/tests/anderson_wrapper_test.py @@ -18,7 +18,7 @@ import jax import jax.numpy as jnp -from jax.config import config +from jax import config from jax.tree_util import tree_map, tree_all from jax.test_util import check_grads import optax