From 4016759551075a3f19a885daf1b2dd42185b2c68 Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Sun, 26 Jan 2025 01:30:51 +0000 Subject: [PATCH] enable setting flags via env vars like JAX_CPU_COLLECTIVES_IMPLEMENTATION=gloo --- jax/_src/config.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/jax/_src/config.py b/jax/_src/config.py index b5656e34a9af..37376653a03f 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -778,6 +778,7 @@ def _set(self, value: _T) -> None: def bool_flag(name, default, *args, **kwargs) -> Flag[bool]: update_hook = kwargs.pop("update_hook", None) + default = bool(os.getenv(name.upper(), default)) holder = Flag(name, default, update_hook) config.add_option(name, holder, bool, args, kwargs) return holder @@ -785,6 +786,7 @@ def bool_flag(name, default, *args, **kwargs) -> Flag[bool]: def int_flag(name, default, *args, **kwargs) -> Flag[int]: update_hook = kwargs.pop("update_hook", None) + default = int(os.getenv(name.upper(), default)) holder = Flag(name, default, update_hook) config.add_option(name, holder, int, args, kwargs) return holder @@ -792,6 +794,7 @@ def int_flag(name, default, *args, **kwargs) -> Flag[int]: def float_flag(name, default, *args, **kwargs) -> Flag[float]: update_hook = kwargs.pop("update_hook", None) + default = float(os.getenv(name.upper(), default)) holder = Flag(name, default, update_hook) config.add_option(name, holder, float, args, kwargs) return holder @@ -799,15 +802,19 @@ def float_flag(name, default, *args, **kwargs) -> Flag[float]: def string_flag(name, default, *args, **kwargs) -> Flag[str]: update_hook = kwargs.pop("update_hook", None) + default = os.getenv(name.upper(), default) holder = Flag(name, default, update_hook) config.add_option(name, holder, str, args, kwargs) return holder -def enum_flag(name, default, *args, **kwargs) -> Flag[str]: +def enum_flag(name, default, enum_values, *args, **kwargs) -> Flag[str]: update_hook = kwargs.pop("update_hook", None) + default = os.getenv(name.upper(), default) + if default not in enum_values: + raise ValueError(f"Invalid value \"{default}\" for JAX flag {name}") holder = Flag(name, default, update_hook) - config.add_option(name, holder, 'enum', args, kwargs) + config.add_option(name, holder, 'enum', (enum_values, *args), kwargs) return holder