From f5a35f6d228abc29ba6e224c20ad97fc07dbd1be Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Sun, 9 Jun 2024 08:54:01 -0700 Subject: [PATCH] [JAX] Update users of jax.tree.map() to be more careful about how they handle Nones. Due to a bug in JAX, JAX previously permitted `jax.tree.map(f, None, x)` where `x` is not `None`, effectively treating `None` as if it were pytree-prefix of any value. But `None` is a pytree container, and it is only a prefix of `None` itself. Fix user code that was relying on this bug. Most commonly, the fix is to write `jax.tree.map(lambda a, b: (None if a is None else f(a, b)), x, y, is_leaf=lambda t: t is None)`. PiperOrigin-RevId: 641687779 --- clrs/_src/baselines.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/clrs/_src/baselines.py b/clrs/_src/baselines.py index 97e30d7a..e481a921 100644 --- a/clrs/_src/baselines.py +++ b/clrs/_src/baselines.py @@ -766,7 +766,9 @@ def _keep_in_algo(k, v): masked_grads = grads else: masked_grads = {k: _keep_in_algo(k, v) for k, v in grads.items()} - flat_grads, treedef = jax.tree_util.tree_flatten(masked_grads) + flat_grads, treedef = jax.tree_util.tree_flatten( + masked_grads, is_leaf=lambda x: x is None + ) flat_opt_state = jax.tree_util.tree_map( lambda _, x: x # pylint:disable=g-long-lambda if isinstance(x, (np.ndarray, jax.Array))