Skip to content

Commit

Permalink
snapshot (#854)
Browse files Browse the repository at this point in the history
  • Loading branch information
kelvin-zou authored Nov 22, 2024
1 parent 1af2ba8 commit d0edead
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 20 deletions.
9 changes: 7 additions & 2 deletions axlearn/common/checkpointer_orbax.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,12 @@ def save_fn_with_summaries(step: int, last_saved_step: Optional[int]) -> bool:
step_prefix=STEP_PREFIX,
step_format_fixed_length=STEP_NUM_DIGITS,
)
# TODO(matthew_e_hopkins): bring back save_concurrent_gb and restore_concurrent_gb
# after bumping up the Jax version.
if cfg.max_concurrent_restore_gb is not None:
raise NotImplementedError(
"Orbax version (0.5.23) doesn't support separate save/restore concurrent_gb."
)
self._manager = ocp.CheckpointManager(
directory=cfg.dir,
options=ocp.CheckpointManagerOptions(
Expand All @@ -245,8 +251,7 @@ def save_fn_with_summaries(step: int, last_saved_step: Optional[int]) -> bool:
# Note that this defaults to use_ocdb=True. Note also that custom `TypeHandler`s are
# ignored by `StandardCheckpointHandler`, so we use `PyTreeCheckpointHandler`.
"state": ocp.PyTreeCheckpointHandler(
save_concurrent_gb=cfg.max_concurrent_save_gb,
restore_concurrent_gb=cfg.max_concurrent_restore_gb,
concurrent_gb=cfg.max_concurrent_save_gb,
),
},
)
Expand Down
41 changes: 27 additions & 14 deletions axlearn/common/checkpointer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,20 +121,33 @@ def test_save_and_restore(self, checkpointer_cls: Type[BaseCheckpointer]):
x=jnp.zeros([], dtype=jnp.int32), y=jnp.ones([3], dtype=jnp.float32)
),
)

# When the given state has a different dict shape: [1] instead of [] for x.
# Orbax throws AssertionError in this case.
with self.assertRaisesRegex(
(AssertionError, ValueError),
"(checkpoint tree dtypes or shapes|not compatible)",
):
ckpt.restore(
step=None,
state=dict(
x=jnp.zeros([1], dtype=jnp.int32),
y=jnp.ones([2], dtype=jnp.float32),
),
)
# TODO(matthew_e_hopkins): revert it once upgrade jax version.
if checkpointer_cls is Checkpointer:
# When the given state has a different dict shape: [1] instead of [] for x.
# Orbax throws AssertionError in this case.
with self.assertRaisesRegex(
(AssertionError, ValueError),
"(checkpoint tree dtypes or shapes|not compatible)",
):
ckpt.restore(
step=None,
state=dict(
x=jnp.zeros([1], dtype=jnp.int32),
y=jnp.ones([2], dtype=jnp.float32),
),
)
else:
with self.assertRaisesRegex(
(AssertionError, ValueError),
"Cannot intersect index domain",
):
ckpt.restore(
step=None,
state=dict(
x=jnp.zeros([1], dtype=jnp.int32),
y=jnp.ones([2], dtype=jnp.float32),
),
)

# When the given state has a different dtype: float32 instead of int32 for x.
with self.assertRaisesRegex(ValueError, "checkpoint tree dtypes or shapes"):
Expand Down
8 changes: 4 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ core = [
"chex==0.1.86", # chex 0.1.86 is required for jax 0.4.25.
"einops==0.8.0",
"importlab==0.7", # breaks pytype on 0.8
"jax==0.4.34",
"jaxlib==0.4.34",
"jax==0.4.33",
"jaxlib==0.4.33",
"nltk==3.7", # for text preprocessing
"optax==0.1.7", # optimizers (0.1.0 has known bugs).
"portpicker",
Expand Down Expand Up @@ -126,7 +126,7 @@ dataflow = [
# GPU custom kernel dependency.
gpu = [
"triton==2.1.0",
"jax[cuda12]==0.4.34",
"jax[cuda12]==0.4.33",
]
# Open API inference.
open_api = [
Expand All @@ -146,7 +146,7 @@ mmau = [
# Orbax checkpointing.
orbax = [
"humanize==4.10.0",
"orbax-checkpoint==0.9.1",
"orbax-checkpoint==0.5.23",
]
# Grain input processing. Currently does not support macos.
grain = [
Expand Down

0 comments on commit d0edead

Please sign in to comment.