diff --git a/torch_xla/experimental/distributed_checkpoint/manager.py b/torch_xla/experimental/distributed_checkpoint/manager.py
index a889d2454b2..f002b78cb6a 100644
--- a/torch_xla/experimental/distributed_checkpoint/manager.py
+++ b/torch_xla/experimental/distributed_checkpoint/manager.py
@@ -53,7 +53,11 @@ class CheckpointManager:
   https://github.com/google/orbax/blob/efc079c4e5b437782a80138913d322cb3ed365c7/checkpoint/orbax/checkpoint/checkpoint_manager.py
   """
 
-  def __init__(self, path: str, save_period: int):
+  def __init__(self,
+               path: str,
+               save_period: int,
+               max_to_keep: Optional[int] = -1,
+               async_queue_size: Optional[int] = 1):
     """
     Create a checkpoint manager that reads and writes checkpoints into
     the provided directory.
@@ -61,6 +65,17 @@ def __init__(self, path: str, save_period: int):
     Args:
       path: The base path for the CheckpointManager to write checkpoints into.
       save_period: The number of steps between saving checkpoints.
+      max_to_keep: The maximum number of checkpoints to be tracked by the
+        CheckpointManager. When a new checkpoint is taken, the checkpoint for
+        the lowest tracked step will be deleted.
+        Default: -1, indicating no upper bound on the number of checkpoints.
+      async_queue_size: The size of the execution queue which processes async
+        checkpoints. This should be a small value to ensure training doesn't
+        get too far ahead of the last finished checkpoint, but increasing
+        the value to 2 can unblock training when there are transient
+        network issues which slow down the active checkpoint.
+        Default: 1, which only allows a single async checkpoint to be
+        pending at a time.
     """
     raise NotImplementedError
 
@@ -82,7 +97,7 @@ def save(self,
       step: The current training step.
       state_dict: The state dict to be checkpointed.
       force: Option to force a checkpoint to be taken regardless of the result
-        of `should_save(step)`
+        of `should_save(step)`.
     Returns:
       True if a checkpoint was taken and False otherwise.
     """
@@ -99,14 +114,15 @@ def save_async(self,
 
     This function will do the following:
     1. Transfer `state_dict` to the CPU device.
-    2. Synchronously wait for any other async checkpoints to finish.
-    3. Start a background thread to take the checkpoint asynchronously.
+    2. Dispatch the checkpoint workload to an asynchronous execution
+       queue. When the queue is full, this will block training until the
+       ongoing async checkpoint finishes.
 
     Args:
       step: The current training step.
       state_dict: The state dict to be checkpointed.
       force: Option to force a checkpoint to be taken regardless of the result
-        of `should_save(step)`
+        of `should_save(step)`.
     Returns:
       True if a checkpoint was taken and False otherwise.
     """
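
For context, a minimal usage sketch of the new constructor arguments and the async queue behavior. It assumes `CheckpointManager` is importable from `torch_xla.experimental.distributed_checkpoint` and that a complete implementation backs this interface (the method bodies in this patch still raise `NotImplementedError`); the model, optimizer, path, and training loop below are placeholders.

```python
import torch
import torch.nn as nn
from torch_xla.experimental.distributed_checkpoint import CheckpointManager

model = nn.Linear(16, 4)
optim = torch.optim.SGD(model.parameters(), lr=0.1)

# Keep at most 3 checkpoints on disk; allow up to 2 async checkpoints to be
# pending before save_async blocks training on the in-flight one.
chkpt_mgr = CheckpointManager(
    path='/tmp/checkpoints',  # hypothetical path
    save_period=100,
    max_to_keep=3,
    async_queue_size=2)

for step in range(1000):
  # ... run one training step here ...

  # Per the docstrings above, this is expected to be a no-op unless
  # should_save(step) is True (every save_period steps) or force=True. The
  # state_dict is moved to CPU before the call returns, then written out on
  # the background execution queue.
  state_dict = {'model': model.state_dict(), 'optim': optim.state_dict()}
  if chkpt_mgr.save_async(step, state_dict):
    print(f'dispatched async checkpoint at step {step}')
```

A small `async_queue_size` keeps training from running many steps ahead of the last durable checkpoint, while the default `max_to_keep=-1` retains every checkpoint ever written.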