Update documentation for async
jonb377 committed Oct 9, 2023
1 parent 305620f commit 46f9706
Showing 1 changed file with 21 additions and 5 deletions.
26 changes: 21 additions & 5 deletions torch_xla/experimental/distributed_checkpoint/manager.py
@@ -53,14 +53,29 @@ class CheckpointManager:
https://github.com/google/orbax/blob/efc079c4e5b437782a80138913d322cb3ed365c7/checkpoint/orbax/checkpoint/checkpoint_manager.py
"""

-  def __init__(self, path: str, save_period: int):
+  def __init__(self,
+               path: str,
+               save_period: int,
+               max_to_keep: Optional[int] = -1,
+               async_queue_size: Optional[int] = 1):
"""
Create a checkpoint manager that reads and writes checkpoints into
the provided directory.
Args:
path: The base path for the CheckpointManager to write checkpoints into.
save_period: The number of steps between saving checkpoints.
max_to_keep: The maximum number of checkpoints to be tracked by the
CheckpointManager. When a new checkpoint will be taken, the
checkpoint for the lowest tracked step will be deleted.
Default: -1, indicating no upper bound on the number of checkpoints.
async_queue_size: The size of the execution queue which processes async
checkpoints. This should be a small value to ensure training doesn't
get too far ahead of the last finished checkpoint, but increasing
the value to 2 can unblock training when there are transient
network issues which slow down the active checkpoint.
Default: 1, which only allows a single async checkpoint to be
pending at a time.
"""
raise NotImplementedError
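For context, a minimal usage sketch of the constructor documented above. The import path mirrors the file's location and is an assumption, and at this commit the constructor still raises NotImplementedError, so this illustrates the intended API rather than working behavior:

from torch_xla.experimental.distributed_checkpoint.manager import CheckpointManager

# All values below are hypothetical; only the parameter names come from the diff.
chkpt_mgr = CheckpointManager(
    path='gs://my-bucket/checkpoints',  # hypothetical base path
    save_period=100,      # attempt a checkpoint every 100 steps
    max_to_keep=5,        # delete the lowest tracked step once 5 are kept
    async_queue_size=2,   # tolerate one slow async checkpoint before blocking
)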

@@ -82,7 +97,7 @@ def save(self,
       step: The current training step.
       state_dict: The state dict to be checkpointed.
       force: Option to force a checkpoint to be taken regardless of the result
-        of `should_save(step)`
+        of `should_save(step)`.
     Returns:
       True if a checkpoint was taken and False otherwise.
     """
@@ -99,14 +114,15 @@ def save_async(self,
     This function will do the following:
     1. Transfer `state_dict` to the CPU device.
-    2. Synchronously wait for any other async checkpoints to finish.
-    3. Start a background thread to take the checkpoint asynchronously.
+    2. Dispatch the checkpoint workload to an asynchronous execution
+       queue. When the queue is full, this will block training until an
+       ongoing async checkpoint finishes.

     Args:
       step: The current training step.
       state_dict: The state dict to be checkpointed.
       force: Option to force a checkpoint to be taken regardless of the result
-        of `should_save(step)`
+        of `should_save(step)`.
     Returns:
       True if a checkpoint was taken and False otherwise.
     """
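The queue semantics described for save_async can be illustrated with a bounded queue feeding a worker thread. This is only a sketch of the documented behavior under stated assumptions (tensor-valued state dicts, a hypothetical storage writer), not the library's implementation:

import queue
import threading

class AsyncQueueSketch:
    """Sketch of the save_async dispatch behavior documented above."""

    def __init__(self, async_queue_size: int = 1):
        # A bounded queue makes put() block once `async_queue_size`
        # checkpoints are pending, matching the documented behavior of
        # blocking training when the queue is full.
        self._pending = queue.Queue(maxsize=async_queue_size)
        threading.Thread(target=self._worker, daemon=True).start()

    def _worker(self):
        while True:
            step, cpu_state_dict = self._pending.get()
            # ... write cpu_state_dict for `step` to storage here ...
            self._pending.task_done()

    def save_async(self, step, state_dict):
        # 1. Move tensors to CPU so training can keep mutating device state.
        #    (Assumes tensor values; real state dicts may be nested.)
        cpu_state_dict = {k: v.cpu() for k, v in state_dict.items()}
        # 2. Dispatch; this put() blocks if the queue is already full.
        self._pending.put((step, cpu_state_dict))
        return True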
