Skip to content

Commit

Permalink
Replace deprecated jax.tree_* functions with jax.tree.*
Browse files Browse the repository at this point in the history
The top-level `jax.tree_*` aliases have long been deprecated, and will soon be removed. Alternate APIs are in `jax.tree_util`, with shorter aliases in the `jax.tree` submodule, added in JAX version 0.4.25.

PiperOrigin-RevId: 634132029
  • Loading branch information
Jake VanderPlas authored and JAXopt authors committed May 16, 2024
1 parent 8da3350 commit b4723a8
Show file tree
Hide file tree
Showing 7 changed files with 19 additions and 19 deletions.
4 changes: 2 additions & 2 deletions docs/notebooks/distributed/custom_loop_pjit_example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,7 @@
" in_axis_resources=PartitionSpec('data'),\n",
" out_axis_resources=PartitionSpec('data'))(*data)\n",
" else: # Just move data to device.\n",
" data = jax.tree_map(jax.device_put, data)\n",
" data = jax.tree.map(jax.device_put, data)\n",
"\n",
" # Pre-compiles update, preventing it from affecting step times.\n",
" tic = time.time()\n",
Expand All @@ -364,7 +364,7 @@
" for it in range(MAXITER):\n",
" tic = time.time()\n",
" params, state = update(params, state, data)\n",
" jax.tree_map(lambda t: t.block_until_ready(), (params, state))\n",
" jax.tree.map(lambda t: t.block_until_ready(), (params, state))\n",
" step_times[it] = time.time() - tic\n",
" errors[it] = state.error.item()\n",
"\n",
Expand Down
4 changes: 2 additions & 2 deletions docs/notebooks/distributed/custom_loop_pjit_example.md
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ def fit(
in_axis_resources=PartitionSpec('data'),
out_axis_resources=PartitionSpec('data'))(*data)
else: # Just move data to device.
data = jax.tree_map(jax.device_put, data)
data = jax.tree.map(jax.device_put, data)
# Pre-compiles update, preventing it from affecting step times.
tic = time.time()
Expand All @@ -285,7 +285,7 @@ def fit(
for it in range(MAXITER):
tic = time.time()
params, state = update(params, state, data)
jax.tree_map(lambda t: t.block_until_ready(), (params, state))
jax.tree.map(lambda t: t.block_until_ready(), (params, state))
step_times[it] = time.time() - tic
errors[it] = state.error.item()
Expand Down
8 changes: 4 additions & 4 deletions docs/notebooks/distributed/custom_loop_pmap_example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@
" maybe_pmean = lambda t: jax.lax.pmean(t, axis_name) if t is not None else t\n",
" @functools.wraps(fun)\n",
" def wrapper(*args, **kwargs):\n",
" return jax.tree_map(maybe_pmean, fun(*args, **kwargs))\n",
" return jax.tree.map(maybe_pmean, fun(*args, **kwargs))\n",
" return wrapper"
]
},
Expand Down Expand Up @@ -321,9 +321,9 @@
" # occur in each update. This is true regardless of whether we use distributed\n",
" # or single-device computation.\n",
" if use_pmap: # Shards data and moves it to device,\n",
" data = jax.tree_map(shard_array, data)\n",
" data = jax.tree.map(shard_array, data)\n",
" else: # Just move data to device.\n",
" data = jax.tree_map(jax.device_put, data)\n",
" data = jax.tree.map(jax.device_put, data)\n",
"\n",
" # Pre-compiles update, preventing it from affecting step times.\n",
" tic = time.time()\n",
Expand All @@ -337,7 +337,7 @@
" for it in range(MAXITER):\n",
" tic = time.time()\n",
" params, state = update(params, state, data)\n",
" jax.tree_map(lambda t: t.block_until_ready(), (params, state))\n",
" jax.tree.map(lambda t: t.block_until_ready(), (params, state))\n",
" step_times[it] = time.time() - tic\n",
" errors[it] = (jax_utils.unreplicate(state.error).item()\n",
" if use_pmap else state.error.item())\n",
Expand Down
8 changes: 4 additions & 4 deletions docs/notebooks/distributed/custom_loop_pmap_example.md
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def pmean(fun: Callable[..., Any], axis_name: str = 'b') -> Callable[..., Any]:
maybe_pmean = lambda t: jax.lax.pmean(t, axis_name) if t is not None else t
@functools.wraps(fun)
def wrapper(*args, **kwargs):
return jax.tree_map(maybe_pmean, fun(*args, **kwargs))
return jax.tree.map(maybe_pmean, fun(*args, **kwargs))
return wrapper
```

Expand Down Expand Up @@ -235,9 +235,9 @@ def fit(
# occur in each update. This is true regardless of whether we use distributed
# or single-device computation.
if use_pmap: # Shards data and moves it to device,
data = jax.tree_map(shard_array, data)
data = jax.tree.map(shard_array, data)
else: # Just move data to device.
data = jax.tree_map(jax.device_put, data)
data = jax.tree.map(jax.device_put, data)
# Pre-compiles update, preventing it from affecting step times.
tic = time.time()
Expand All @@ -251,7 +251,7 @@ def fit(
for it in range(MAXITER):
tic = time.time()
params, state = update(params, state, data)
jax.tree_map(lambda t: t.block_until_ready(), (params, state))
jax.tree.map(lambda t: t.block_until_ready(), (params, state))
step_times[it] = time.time() - tic
errors[it] = (jax_utils.unreplicate(state.error).item()
if use_pmap else state.error.item())
Expand Down
6 changes: 3 additions & 3 deletions examples/deep_learning/distributed_flax_imagenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,7 +489,7 @@ def loss_fun(
mutable=['batch_stats'])

xentropy = cross_entropy_loss(labels=batch['label'], logits=logits)
weight_penalty_params = [x for x in jax.tree_leaves(params) if x.ndim > 1]
weight_penalty_params = [x for x in jax.tree.leaves(params) if x.ndim > 1]
weight_l2 = tree_util.tree_l2_norm(weight_penalty_params, squared=True)
loss = xentropy + weight_decay * 0.5 * weight_l2

Expand Down Expand Up @@ -647,9 +647,9 @@ def zeros_like_fun_output(
"""Replaces fun, outputting a pytree of zeroes with the original structure."""
def wrapper(*args, **kwargs):
pytree = jax.eval_shape(fun, *args, **kwargs)
leaves, treedef = jax.tree_flatten(pytree)
leaves, treedef = jax.tree.flatten(pytree)
leaves = [jnp.zeros(shape=leaf.shape, dtype=leaf.dtype) for leaf in leaves]
zeros_like_pytree = jax.tree_unflatten(treedef, leaves)
zeros_like_pytree = jax.tree.unflatten(treedef, leaves)
return zeros_like_pytree if index is None else zeros_like_pytree[index]
return wrapper

Expand Down
6 changes: 3 additions & 3 deletions jaxopt/_src/levenberg_marquardt.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,14 +296,14 @@ def gain_ratio_test_false_func(params, damping_factor,
gain_ratio_test_not_met_ret = gain_ratio_test_false_func(
*gain_ratio_test_init_state)

gain_ratio_test_is_met_ret = jax.tree_map(
gain_ratio_test_is_met_ret = jax.tree.map(
lambda x: gain_ratio_test_is_met * x, gain_ratio_test_is_met_ret)

gain_ratio_test_not_met_ret = jax.tree_map(
gain_ratio_test_not_met_ret = jax.tree.map(
lambda x: (1.0 - gain_ratio_test_is_met) * x,
gain_ratio_test_not_met_ret)

params, damping_factor, increase_factor, residual, gradient, jac, jt, jtj, hess_res, aux = jax.tree_map(
params, damping_factor, increase_factor, residual, gradient, jac, jt, jtj, hess_res, aux = jax.tree.map(
lambda x, y: x + y, gain_ratio_test_is_met_ret,
gain_ratio_test_not_met_ret)

Expand Down
2 changes: 1 addition & 1 deletion jaxopt/_src/optax_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def init_state(self,

def _apply_updates(self, params, updates):
update_fun = lambda p, u: jnp.asarray(p + u).astype(jnp.asarray(p).dtype)
return jax.tree_map(update_fun, params, updates)
return jax.tree.map(update_fun, params, updates)

def update(self,
params: Any,
Expand Down

0 comments on commit b4723a8

Please sign in to comment.