Skip to content

Commit

Permalink
Initial commit for CheckpointManager (#5678)
Browse files Browse the repository at this point in the history
* Initial commit for CheckpointManager

* Update documentation for async

* Fix typo
  • Loading branch information
jonb377 authored Oct 10, 2023
1 parent d9a9049 commit 4524543
Show file tree
Hide file tree
Showing 2 changed files with 150 additions and 0 deletions.
2 changes: 2 additions & 0 deletions torch_xla/experimental/distributed_checkpoint/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from .manager import CheckpointManager
from .planners import SPMDSavePlanner, SPMDLoadPlanner

__all__ = [
"CheckpointManager",
"SPMDSavePlanner",
"SPMDLoadPlanner",
]
148 changes: 148 additions & 0 deletions torch_xla/experimental/distributed_checkpoint/manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
import torch.distributed.checkpoint as dist_cp
import torch_xla.experimental.distributed_checkpoint as xc

from typing import List, Optional
from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE


class CheckpointManager:
"""
The CheckpointManager class provides a higher-level wrapper around the
torch.distributed.checkpoint APIs to manage checkpointing. It builds on top
of those APIs to enable a few key features:
- Per-step checkpointing: Each checkpoint taken by the CheckpointManager is
identified by the step at which it was taken, and any step tracked
by the CheckpointManager can be restored.
- Async checkpointing: The torch.distributed.checkpoint APIs are
synchronous, which will block training for the duration of the
checkpoint. The CheckpointManager's save_async method can be used to
offload checkpointing to a background thread, unblocking training
while the checkpoint is written to persistent storage.
- Automatic checkpointing: If the training process would be shut down due
to a SIGTERM, the CheckpointManager will automatically take a
checkpoint at the next step.
- Native fsspec integration: Any storage protocol compatible with fsspec
can be used with CheckpointManager.
The intended usage of CheckpointManager is as follows:
>>> # Create a CheckpointManager to checkpoint every 10 steps into GCS.
>>> chkpt_mgr = CheckpointManager('gs://my-bucket/my-experiment', 10)
>>> # Select a checkpoint to restore from, and restore if applicable
>>> tracked_steps = chkpt_mgr.all_steps()
>>> if tracked_steps:
>>> # Choose the highest step
>>> best_step = max(tracked_steps)
>>> state_dict = {'model': model.state_dict()}
>>> chkpt_mgr.restore(best_step, state_dict)
>>> model.load_state_dict(state_dict['model'])
>>> # Call `save` or `save_async` every step within the train loop.
>>> for step, data in enumerate(dataloader):
>>> ...
>>> state_dict = {'model': model.state_dict(), 'optim': optim.state_dict()}
>>> if chkpt_mgr.save_async(step, state_dict):
>>> print(f'Checkpoint taken at step {step}')
By calling `save` or `save_async` every step, the CheckpointManager has the
opportunity to take a checkpoint on steps which are out-of-cycle with its
step_period, as would be the case in auto checkpointing.
This class is inspired by Orbax's CheckpointManager, which can be found here:
https://github.com/google/orbax/blob/efc079c4e5b437782a80138913d322cb3ed365c7/checkpoint/orbax/checkpoint/checkpoint_manager.py
"""

def __init__(self,
path: str,
save_period: int,
max_to_keep: Optional[int] = -1,
async_queue_size: Optional[int] = 1):
"""
Create a checkpoint manager that reads and writes checkpoints into
the provided directory.
Args:
path: The base path for the CheckpointManager to write checkpoints into.
save_period: The number of steps between saving checkpoints.
max_to_keep: The maximum number of checkpoints to be tracked by the
CheckpointManager. When a new checkpoint will be taken, the
checkpoint for the lowest tracked step will be deleted.
Default: -1, indicating no upper bound on the number of checkpoints.
async_queue_size: The size of the execution queue which processes async
checkpoints. This should be a small value to ensure training doesn't
get too far ahead of the last finished checkpoint, but increasing
the value to 2 can unblock training when there are transient
network issues which slow down the active checkpoint.
Default: 1, which only allows a single async checkpoint to be
pending at a time.
"""
raise NotImplementedError

def should_save(self, step: int) -> bool:
"""
Returns true if a checkpoint should be saved for the current step or if
a preemption has been detected.
"""
raise NotImplementedError

def save(self,
step,
state_dict: STATE_DICT_TYPE,
force: Optional[bool] = False) -> bool:
"""
Take a checkpoint synchronously if `self.should_save(step)`.
Args:
step: The current training step.
state_dict: The state dict to be checkpointed.
force: Option to force a checkpoint to be taken regardless of the result
of `should_save(step)`.
Returns:
True if a checkpoint was taken and False otherwise.
"""
raise NotImplementedError

def save_async(self,
step: int,
state_dict: STATE_DICT_TYPE,
force: Optional[bool] = False) -> bool:
"""
Take a checkpoint asynchronously if `self.should_save(step)`. The
input state_dict will be transferred to the CPU device using the
`sharded_cpu_state_dict` function.
This function will do the following:
1. Transfer `state_dict` to the CPU device.
2. Dispatch the checkpoint workload to an asynchronous execution
queue. This will block training until the ongoing async
checkpoint finishes when the queue is full.
Args:
step: The current training step.
state_dict: The state dict to be checkpointed.
force: Option to force a checkpoint to be taken regardless of the result
of `should_save(step)`.
Returns:
True if a checkpoint was taken and False otherwise.
"""
raise NotImplementedError

def restore(self, step: int, state_dict: STATE_DICT_TYPE) -> None:
"""
Restores the checkpoint taken at the given step into the state_dict. The
caller is responsible for calling `model.load_state_dict` to restore any
non-tensor values.
Args:
step: The step whose checkpoint is to be restored.
state_dict: The state dict to restore the checkpoint into. Values are
updated in-place within the state_dict.
"""
raise NotImplementedError

def all_steps(self) -> List[int]:
"""
List all steps tracked by the CheckpointManager.
"""
raise NotImplementedError

0 comments on commit 4524543

Please sign in to comment.