diff --git a/axlearn/common/checkpointer_orbax.py b/axlearn/common/checkpointer_orbax.py index 26c84d57..affe69cb 100644 --- a/axlearn/common/checkpointer_orbax.py +++ b/axlearn/common/checkpointer_orbax.py @@ -226,6 +226,12 @@ def save_fn_with_summaries(step: int, last_saved_step: Optional[int]) -> bool: step_prefix=STEP_PREFIX, step_format_fixed_length=STEP_NUM_DIGITS, ) + # TODO(matthew_e_hopkins): bring back save_concurrent_gb and restore_concurrent_gb + # after bumping up the Jax version. + if cfg.max_concurrent_restore_gb is not None: + raise NotImplementedError( + "Orbax version (0.5.23) doesn't support separate save/restore concurrent_gb." + ) self._manager = ocp.CheckpointManager( directory=cfg.dir, options=ocp.CheckpointManagerOptions( @@ -245,8 +251,7 @@ def save_fn_with_summaries(step: int, last_saved_step: Optional[int]) -> bool: # Note that this defaults to use_ocdb=True. Note also that custom `TypeHandler`s are # ignored by `StandardCheckpointHandler`, so we use `PyTreeCheckpointHandler`. "state": ocp.PyTreeCheckpointHandler( - save_concurrent_gb=cfg.max_concurrent_save_gb, - restore_concurrent_gb=cfg.max_concurrent_restore_gb, + concurrent_gb=cfg.max_concurrent_save_gb, ), }, ) diff --git a/axlearn/common/checkpointer_test.py b/axlearn/common/checkpointer_test.py index 7e348530..de8f81d2 100644 --- a/axlearn/common/checkpointer_test.py +++ b/axlearn/common/checkpointer_test.py @@ -121,20 +121,33 @@ def test_save_and_restore(self, checkpointer_cls: Type[BaseCheckpointer]): x=jnp.zeros([], dtype=jnp.int32), y=jnp.ones([3], dtype=jnp.float32) ), ) - - # When the given state has a different dict shape: [1] instead of [] for x. - # Orbax throws AssertionError in this case. - with self.assertRaisesRegex( - (AssertionError, ValueError), - "(checkpoint tree dtypes or shapes|not compatible)", - ): - ckpt.restore( - step=None, - state=dict( - x=jnp.zeros([1], dtype=jnp.int32), - y=jnp.ones([2], dtype=jnp.float32), - ), - ) + # TODO(matthew_e_hopkins): revert it once upgrade jax version. + if checkpointer_cls is Checkpointer: + # When the given state has a different dict shape: [1] instead of [] for x. + # Orbax throws AssertionError in this case. + with self.assertRaisesRegex( + (AssertionError, ValueError), + "(checkpoint tree dtypes or shapes|not compatible)", + ): + ckpt.restore( + step=None, + state=dict( + x=jnp.zeros([1], dtype=jnp.int32), + y=jnp.ones([2], dtype=jnp.float32), + ), + ) + else: + with self.assertRaisesRegex( + (AssertionError, ValueError), + "Cannot intersect index domain", + ): + ckpt.restore( + step=None, + state=dict( + x=jnp.zeros([1], dtype=jnp.int32), + y=jnp.ones([2], dtype=jnp.float32), + ), + ) # When the given state has a different dtype: float32 instead of int32 for x. with self.assertRaisesRegex(ValueError, "checkpoint tree dtypes or shapes"): diff --git a/pyproject.toml b/pyproject.toml index d0bc333a..35610db0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,8 +24,8 @@ core = [ "chex==0.1.86", # chex 0.1.86 is required for jax 0.4.25. "einops==0.8.0", "importlab==0.7", # breaks pytype on 0.8 - "jax==0.4.34", - "jaxlib==0.4.34", + "jax==0.4.33", + "jaxlib==0.4.33", "nltk==3.7", # for text preprocessing "optax==0.1.7", # optimizers (0.1.0 has known bugs). "portpicker", @@ -126,7 +126,7 @@ dataflow = [ # GPU custom kernel dependency. gpu = [ "triton==2.1.0", - "jax[cuda12]==0.4.34", + "jax[cuda12]==0.4.33", ] # Open API inference. open_api = [ @@ -146,7 +146,7 @@ mmau = [ # Orbax checkpointing. orbax = [ "humanize==4.10.0", - "orbax-checkpoint==0.9.1", + "orbax-checkpoint==0.5.23", ] # Grain input processing. Currently does not support macos. grain = [