Skip to content

Commit

Permalink
Use 0 to indicate no upper bound
Browse the repository at this point in the history
  • Loading branch information
jonb377 committed Oct 10, 2023
1 parent 4f88d9f commit 7326cc8
Showing 1 changed file with 10 additions and 4 deletions.
14 changes: 10 additions & 4 deletions torch_xla/experimental/distributed_checkpoint/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,16 @@ class CheckpointManager:
# The period to take checkpoints, in steps.
save_period: int

# The maximum number of checkpoints to keep.
max_to_keep: int

# The size of the queue which processes async checkpoints.
async_queue_size: int

def __init__(self,
path: str,
save_period: int,
max_to_keep: Optional[int] = -1,
max_to_keep: Optional[int] = 0,
async_queue_size: Optional[int] = 1):
"""
Create a checkpoint manager that reads and writes checkpoints into
Expand All @@ -84,7 +90,7 @@ def __init__(self,
max_to_keep: The maximum number of checkpoints to be tracked by the
CheckpointManager. When a new checkpoint will be taken, the
checkpoint for the lowest tracked step will be deleted.
Default: -1, indicating no upper bound on the number of checkpoints.
Default: 0, indicating no upper bound on the number of checkpoints.
async_queue_size: The size of the execution queue which processes async
checkpoints. This should be a small value to ensure training doesn't
get too far ahead of the last finished checkpoint, but increasing
Expand All @@ -101,7 +107,7 @@ def __init__(self,
self.async_queue_size = async_queue_size
assert self.save_period > 0, "save_period must be positive"
assert self.async_queue_size > 0, "async_queue_size must be positive"
assert self.max_to_keep != 0, "max_to_keep must be non-zero"
assert self.max_to_keep >= 0, "max_to_keep must be non-negative"

def _get_path(self, step: int) -> str:
return os.path.join(self.base_path, str(step))
Expand All @@ -110,7 +116,7 @@ def _release_oldest_checkpoints(self):
if self.max_to_keep > 0:
tracked_steps = sorted(self.all_steps(), reverse=True)
while len(tracked_steps) > self.max_to_keep:
# Delete the oldest checkpoint step to free up space for the new one.
# Delete the oldest checkpoint step
oldest_step = tracked_steps.pop()
path = self._get_path(oldest_step)
fs, raw_path = url_to_fs(path)
Expand Down

0 comments on commit 7326cc8

Please sign in to comment.