From 10c92514594fb3c33d4b3ff206a2368223b2ada8 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 15 May 2024 17:28:01 -0700 Subject: [PATCH] Replace deprecated `jax.tree_*` functions with `jax.tree.*` The top-level `jax.tree_*` aliases have long been deprecated, and will soon be removed. Alternate APIs are in `jax.tree_util`, with shorter aliases in the `jax.tree` submodule, added in JAX version 0.4.25. PiperOrigin-RevId: 634132029 --- docs/notebooks/distributed/custom_loop_pjit_example.ipynb | 4 ++-- docs/notebooks/distributed/custom_loop_pjit_example.md | 4 ++-- docs/notebooks/distributed/custom_loop_pmap_example.ipynb | 8 ++++---- docs/notebooks/distributed/custom_loop_pmap_example.md | 8 ++++---- examples/deep_learning/distributed_flax_imagenet.py | 6 +++--- jaxopt/_src/levenberg_marquardt.py | 6 +++--- jaxopt/_src/optax_wrapper.py | 2 +- 7 files changed, 19 insertions(+), 19 deletions(-) diff --git a/docs/notebooks/distributed/custom_loop_pjit_example.ipynb b/docs/notebooks/distributed/custom_loop_pjit_example.ipynb index 1df42240..3f1a8cf9 100644 --- a/docs/notebooks/distributed/custom_loop_pjit_example.ipynb +++ b/docs/notebooks/distributed/custom_loop_pjit_example.ipynb @@ -350,7 +350,7 @@ " in_axis_resources=PartitionSpec('data'),\n", " out_axis_resources=PartitionSpec('data'))(*data)\n", " else: # Just move data to device.\n", - " data = jax.tree_map(jax.device_put, data)\n", + " data = jax.tree.map(jax.device_put, data)\n", "\n", " # Pre-compiles update, preventing it from affecting step times.\n", " tic = time.time()\n", @@ -364,7 +364,7 @@ " for it in range(MAXITER):\n", " tic = time.time()\n", " params, state = update(params, state, data)\n", - " jax.tree_map(lambda t: t.block_until_ready(), (params, state))\n", + " jax.tree.map(lambda t: t.block_until_ready(), (params, state))\n", " step_times[it] = time.time() - tic\n", " errors[it] = state.error.item()\n", "\n", diff --git a/docs/notebooks/distributed/custom_loop_pjit_example.md b/docs/notebooks/distributed/custom_loop_pjit_example.md index aa5c15cb..606ef9b1 100644 --- a/docs/notebooks/distributed/custom_loop_pjit_example.md +++ b/docs/notebooks/distributed/custom_loop_pjit_example.md @@ -271,7 +271,7 @@ def fit( in_axis_resources=PartitionSpec('data'), out_axis_resources=PartitionSpec('data'))(*data) else: # Just move data to device. - data = jax.tree_map(jax.device_put, data) + data = jax.tree.map(jax.device_put, data) # Pre-compiles update, preventing it from affecting step times. tic = time.time() @@ -285,7 +285,7 @@ def fit( for it in range(MAXITER): tic = time.time() params, state = update(params, state, data) - jax.tree_map(lambda t: t.block_until_ready(), (params, state)) + jax.tree.map(lambda t: t.block_until_ready(), (params, state)) step_times[it] = time.time() - tic errors[it] = state.error.item() diff --git a/docs/notebooks/distributed/custom_loop_pmap_example.ipynb b/docs/notebooks/distributed/custom_loop_pmap_example.ipynb index b642890e..0f6a7f92 100644 --- a/docs/notebooks/distributed/custom_loop_pmap_example.ipynb +++ b/docs/notebooks/distributed/custom_loop_pmap_example.ipynb @@ -215,7 +215,7 @@ " maybe_pmean = lambda t: jax.lax.pmean(t, axis_name) if t is not None else t\n", " @functools.wraps(fun)\n", " def wrapper(*args, **kwargs):\n", - " return jax.tree_map(maybe_pmean, fun(*args, **kwargs))\n", + " return jax.tree.map(maybe_pmean, fun(*args, **kwargs))\n", " return wrapper" ] }, @@ -321,9 +321,9 @@ " # occur in each update. 
This is true regardless of whether we use distributed\n", " # or single-device computation.\n", " if use_pmap: # Shards data and moves it to device,\n", - " data = jax.tree_map(shard_array, data)\n", + " data = jax.tree.map(shard_array, data)\n", " else: # Just move data to device.\n", - " data = jax.tree_map(jax.device_put, data)\n", + " data = jax.tree.map(jax.device_put, data)\n", "\n", " # Pre-compiles update, preventing it from affecting step times.\n", " tic = time.time()\n", @@ -337,7 +337,7 @@ " for it in range(MAXITER):\n", " tic = time.time()\n", " params, state = update(params, state, data)\n", - " jax.tree_map(lambda t: t.block_until_ready(), (params, state))\n", + " jax.tree.map(lambda t: t.block_until_ready(), (params, state))\n", " step_times[it] = time.time() - tic\n", " errors[it] = (jax_utils.unreplicate(state.error).item()\n", " if use_pmap else state.error.item())\n", diff --git a/docs/notebooks/distributed/custom_loop_pmap_example.md b/docs/notebooks/distributed/custom_loop_pmap_example.md index 1a8dee3f..d09952f4 100644 --- a/docs/notebooks/distributed/custom_loop_pmap_example.md +++ b/docs/notebooks/distributed/custom_loop_pmap_example.md @@ -149,7 +149,7 @@ def pmean(fun: Callable[..., Any], axis_name: str = 'b') -> Callable[..., Any]: maybe_pmean = lambda t: jax.lax.pmean(t, axis_name) if t is not None else t @functools.wraps(fun) def wrapper(*args, **kwargs): - return jax.tree_map(maybe_pmean, fun(*args, **kwargs)) + return jax.tree.map(maybe_pmean, fun(*args, **kwargs)) return wrapper ``` @@ -235,9 +235,9 @@ def fit( # occur in each update. This is true regardless of whether we use distributed # or single-device computation. if use_pmap: # Shards data and moves it to device, - data = jax.tree_map(shard_array, data) + data = jax.tree.map(shard_array, data) else: # Just move data to device. - data = jax.tree_map(jax.device_put, data) + data = jax.tree.map(jax.device_put, data) # Pre-compiles update, preventing it from affecting step times. 
tic = time.time() @@ -251,7 +251,7 @@ def fit( for it in range(MAXITER): tic = time.time() params, state = update(params, state, data) - jax.tree_map(lambda t: t.block_until_ready(), (params, state)) + jax.tree.map(lambda t: t.block_until_ready(), (params, state)) step_times[it] = time.time() - tic errors[it] = (jax_utils.unreplicate(state.error).item() if use_pmap else state.error.item()) diff --git a/examples/deep_learning/distributed_flax_imagenet.py b/examples/deep_learning/distributed_flax_imagenet.py index 373b1de2..1238cffe 100644 --- a/examples/deep_learning/distributed_flax_imagenet.py +++ b/examples/deep_learning/distributed_flax_imagenet.py @@ -489,7 +489,7 @@ def loss_fun( mutable=['batch_stats']) xentropy = cross_entropy_loss(labels=batch['label'], logits=logits) - weight_penalty_params = [x for x in jax.tree_leaves(params) if x.ndim > 1] + weight_penalty_params = [x for x in jax.tree.leaves(params) if x.ndim > 1] weight_l2 = tree_util.tree_l2_norm(weight_penalty_params, squared=True) loss = xentropy + weight_decay * 0.5 * weight_l2 @@ -647,9 +647,9 @@ def zeros_like_fun_output( """Replaces fun, outputting a pytree of zeroes with the original structure.""" def wrapper(*args, **kwargs): pytree = jax.eval_shape(fun, *args, **kwargs) - leaves, treedef = jax.tree_flatten(pytree) + leaves, treedef = jax.tree.flatten(pytree) leaves = [jnp.zeros(shape=leaf.shape, dtype=leaf.dtype) for leaf in leaves] - zeros_like_pytree = jax.tree_unflatten(treedef, leaves) + zeros_like_pytree = jax.tree.unflatten(treedef, leaves) return zeros_like_pytree if index is None else zeros_like_pytree[index] return wrapper diff --git a/jaxopt/_src/levenberg_marquardt.py b/jaxopt/_src/levenberg_marquardt.py index 1a3646ac..45f91d7e 100644 --- a/jaxopt/_src/levenberg_marquardt.py +++ b/jaxopt/_src/levenberg_marquardt.py @@ -296,14 +296,14 @@ def gain_ratio_test_false_func(params, damping_factor, gain_ratio_test_not_met_ret = gain_ratio_test_false_func( *gain_ratio_test_init_state) - gain_ratio_test_is_met_ret = jax.tree_map( + gain_ratio_test_is_met_ret = jax.tree.map( lambda x: gain_ratio_test_is_met * x, gain_ratio_test_is_met_ret) - gain_ratio_test_not_met_ret = jax.tree_map( + gain_ratio_test_not_met_ret = jax.tree.map( lambda x: (1.0 - gain_ratio_test_is_met) * x, gain_ratio_test_not_met_ret) - params, damping_factor, increase_factor, residual, gradient, jac, jt, jtj, hess_res, aux = jax.tree_map( + params, damping_factor, increase_factor, residual, gradient, jac, jt, jtj, hess_res, aux = jax.tree.map( lambda x, y: x + y, gain_ratio_test_is_met_ret, gain_ratio_test_not_met_ret) diff --git a/jaxopt/_src/optax_wrapper.py b/jaxopt/_src/optax_wrapper.py index b17f4a66..e8c69b92 100644 --- a/jaxopt/_src/optax_wrapper.py +++ b/jaxopt/_src/optax_wrapper.py @@ -117,7 +117,7 @@ def init_state(self, def _apply_updates(self, params, updates): update_fun = lambda p, u: jnp.asarray(p + u).astype(jnp.asarray(p).dtype) - return jax.tree_map(update_fun, params, updates) + return jax.tree.map(update_fun, params, updates) def update(self, params: Any,
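
A minimal sketch of the renaming applied throughout this patch, assuming JAX >= 0.4.25 (where the `jax.tree` submodule was introduced, as noted in the commit message). The `params` pytree and the lambdas below are hypothetical placeholders for illustration; only the `jax.tree.*` / `jax.tree_util.*` spellings themselves come from the change above.

    import jax
    import jax.numpy as jnp

    params = {"w": jnp.ones((3, 2)), "b": jnp.zeros(2)}  # hypothetical pytree

    # Deprecated top-level aliases replaced by this patch:
    #   jax.tree_map, jax.tree_leaves, jax.tree_flatten, jax.tree_unflatten

    # Short aliases in the jax.tree submodule (JAX >= 0.4.25):
    doubled = jax.tree.map(lambda x: 2 * x, params)
    leaves = jax.tree.leaves(params)
    flat, treedef = jax.tree.flatten(params)
    rebuilt = jax.tree.unflatten(treedef, [jnp.zeros_like(l) for l in flat])

    # Equivalent long-form APIs in jax.tree_util, available on older releases as well:
    doubled = jax.tree_util.tree_map(lambda x: 2 * x, params)
    leaves = jax.tree_util.tree_leaves(params)
    flat, treedef = jax.tree_util.tree_flatten(params)
    rebuilt = jax.tree_util.tree_unflatten(treedef, [jnp.zeros_like(l) for l in flat])

As in the Levenberg-Marquardt hunk above, `jax.tree.map` also accepts multiple pytrees with matching structure (e.g. `jax.tree.map(lambda x, y: x + y, tree_a, tree_b)`), which is how the two gain-ratio branch results are combined.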