Skip to content

Commit

Permalink
Replace deprecated jax.tree_* functions with jax.tree.*
Browse files Browse the repository at this point in the history
The top-level `jax.tree_*` aliases have long been deprecated, and will soon be removed. Alternate APIs are in `jax.tree_util`, with shorter aliases in the `jax.tree` submodule, added in JAX version 0.4.25.

PiperOrigin-RevId: 634381464
  • Loading branch information
Jake VanderPlas authored and Selforg Gardener committed May 16, 2024
1 parent bb32aa7 commit f188d30
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions notebooks/jax_raycast.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -815,7 +815,7 @@
"\n",
"def render_frame(t):\n",
" t = t*t*(3-2*t) # easing\n",
" s = jax.tree_map(partial(cubic, t), s0, s1, s2, s0)\n",
" s = jax.tree.map(partial(cubic, t), s0, s1, s2, s0)\n",
" return render_scene(**s)\n",
"\n",
"animate(render_frame, 10)"
Expand Down Expand Up @@ -857,7 +857,7 @@
{
"cell_type": "markdown",
"source": [
"The code above uses `jax.tree_map` to conveniently interpolate between nested data structures that contain information about the scene, including light and camera position.\n",
"The code above uses `jax.tree.map` to conveniently interpolate between nested data structures that contain information about the scene, including light and camera position.\n",
"\n",
"Rendering a 10 second 60 fps animation takes from ~20 to less than 10 seconds depending on the available GPU (I tried V100 and A100, also running the first time in a Colab session takes a bit longer). Not bad for a ~100 lines of code engine built from scratch on top of the generic array processing language.\n",
"\n",
Expand Down

0 comments on commit f188d30

Please sign in to comment.