Skip to content

Commit

Permalink
Stackless yashful
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 681582933
  • Loading branch information
dougalm authored and KfacJaxDev committed Oct 9, 2024
1 parent d31f7ac commit d478f67
Showing 1 changed file with 1 addition and 9 deletions.
10 changes: 1 addition & 9 deletions kfac_jax/_src/utils/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,7 @@ def in_pmap(axis_name: str | None) -> bool:
if axis_name is None:
return False

try:
# The only way to know if we are under `jax.pmap` is to check if the
# function call below raises a `NameError` or not.
core.axis_frame(axis_name)

return True

except NameError:
return False
return axis_name in core.unsafe_get_axis_names_DO_NOT_USE()


def wrap_if_pmap(
Expand Down

0 comments on commit d478f67

Please sign in to comment.