Support synchronous saving and loading in CheckpointManager (#5693)
* Support synchronous saving and loading in CheckpointManager

* Use 0 to indicate no upper bound

* Don't track async_queue_size

* Cache tracked steps locally

* Track creation time in metadata

* Rename save_period to save_interval

* Fix tests
jonb377 authored and bhavya01 committed Apr 22, 2024
1 parent 8b9ebae commit 7c3faf9
Showing 2 changed files with 211 additions and 11 deletions.
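As a quick orientation before the diffs, here is a minimal usage sketch of the synchronous API this commit adds. It is inferred from the changes below and is not part of the commit; the gloo process group, address, checkpoint path, and toy model are placeholder assumptions (a working torch_xla install is also assumed).

import torch
import torch.distributed as dist
from torch_xla.experimental.distributed_checkpoint import CheckpointManager

# CheckpointManager asserts that a process group has been initialized.
dist.init_process_group(
    backend='gloo', init_method='tcp://127.0.0.1:8932', world_size=1, rank=0)

# Checkpoint every 10 steps into per-step subdirectories, keeping only the 2 newest.
model = torch.nn.Linear(8, 8)
chkpt_mgr = CheckpointManager('/tmp/checkpoints', save_interval=10, max_to_keep=2)

for step in range(100):
    # ... training step would go here ...
    # save() returns True only when step % save_interval == 0 (or when force=True);
    # once more than max_to_keep checkpoints are tracked, the oldest is deleted.
    chkpt_mgr.save(step, model.state_dict())

dist.destroy_process_group()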
96 changes: 95 additions & 1 deletion test/spmd/test_xla_distributed_checkpoint.py
@@ -1,3 +1,4 @@
import functools
import os
import sys
import tempfile
@@ -15,11 +16,23 @@
create_default_local_save_plan,
create_default_global_save_plan,
)
-from torch_xla.experimental.distributed_checkpoint import SPMDLoadPlanner, SPMDSavePlanner
+from torch_xla.experimental.distributed_checkpoint import SPMDLoadPlanner, SPMDSavePlanner, CheckpointManager
from torch_xla.experimental.distributed_checkpoint._helpers import (
_sharded_cpu_state_dict, _CpuShards, _is_sharded_tensor)


# Wrapper to manage a temporary directory for the wrapped test
def run_with_tmpdir(f):

@functools.wraps(f)
def run(*args, **kwargs):
with tempfile.TemporaryDirectory() as tmpdir:
kwargs.setdefault('tmpdir', tmpdir)
f(*args, **kwargs)

return run


class DistributedCheckpointTestBase(test_xla_sharding_base.XlaShardingTest):

@classmethod
@@ -319,6 +332,87 @@ def test_sharded_cpu_state_dict(self):
self.assertTrue(param.device == torch.device("cpu"))


class CheckpointManagerTest(DistributedCheckpointTestBase):

def setUp(self):
super().setUp()
# Initialize a minimal process group
dist.init_process_group(
backend='gloo', init_method='tcp://127.1:8932', world_size=1, rank=0)

def tearDown(self):
super().tearDown()
# Destroy the CPU process group after the test
dist.destroy_process_group()

@run_with_tmpdir
def test_manager_checkpointing(self, tmpdir):
chkpt_mgr = CheckpointManager(tmpdir, save_interval=10)
state_dict = self._get_sharded_model().state_dict()

# Take a checkpoint on step 0
self.assertTrue(chkpt_mgr.save(0, state_dict))

# Load the checkpoint into a new state_dict
new_state_dict = self._get_sharded_model().state_dict()
self.assertFalse(
any(
torch.allclose(v, new_state_dict[k])
for k, v in state_dict.items()))
chkpt_mgr.restore(0, new_state_dict)
self.assertTrue(
all(
torch.allclose(v, new_state_dict[k])
for k, v in state_dict.items()))

@run_with_tmpdir
def test_manager_step_tracking(self, tmpdir):
chkpt_mgr = CheckpointManager(tmpdir, save_interval=10)
state_dict = self._get_sharded_model().state_dict()

# No steps are being tracked initially
self.assertEqual(chkpt_mgr.all_steps(), [])

# Steps not divisible by 10 should not be saved
for step in range(1, 10):
self.assertFalse(chkpt_mgr.save(step, state_dict))
self.assertEqual(chkpt_mgr.all_steps(), [])

# Steps divisible by 10 should be saved
saved = set()
for step in range(0, 100, 10):
self.assertTrue(chkpt_mgr.save(step, state_dict))
saved.add(step)
self.assertEqual(set(chkpt_mgr.all_steps()), saved)

@run_with_tmpdir
def test_manager_max_to_keep(self, tmpdir):
chkpt_mgr = CheckpointManager(tmpdir, save_interval=10, max_to_keep=2)
state_dict = self._get_sharded_model().state_dict()

# No steps are being tracked initially
self.assertEqual(chkpt_mgr.all_steps(), [])

self.assertTrue(chkpt_mgr.save(10, state_dict))
self.assertEqual(set(chkpt_mgr.all_steps()), {10})

self.assertTrue(chkpt_mgr.save(20, state_dict))
self.assertEqual(set(chkpt_mgr.all_steps()), {10, 20})

# The oldest checkpoint should be erased
self.assertTrue(chkpt_mgr.save(30, state_dict))
self.assertEqual(set(chkpt_mgr.all_steps()), {30, 20})

# The oldest is selected by creation timestamp, not step
self.assertTrue(chkpt_mgr.save(10, state_dict))
self.assertEqual(set(chkpt_mgr.all_steps()), {30, 10})

# The deletion order should persist across executions
chkpt_mgr = CheckpointManager(tmpdir, save_interval=10, max_to_keep=2)
self.assertTrue(chkpt_mgr.save(20, state_dict))
self.assertEqual(set(chkpt_mgr.all_steps()), {20, 10})


if __name__ == '__main__':
test = unittest.main()
sys.exit(0 if test.result.wasSuccessful() else 1)
126 changes: 116 additions & 10 deletions torch_xla/experimental/distributed_checkpoint/manager.py
@@ -1,9 +1,35 @@
import fsspec
import logging
import os
import pickle
import torch.distributed as dist
import torch.distributed.checkpoint as dist_cp
import torch_xla.runtime as xr
import torch_xla.experimental.distributed_checkpoint as xc

from typing import List, Optional
from dataclasses import dataclass
from datetime import datetime
from collections import deque
from fsspec.core import url_to_fs
from os.path import basename
from typing import Deque, List, Optional, Union
from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE

# TODO(jonbolin): Import path will change
from torch.distributed.checkpoint._fsspec_filesystem import FsspecReader, FsspecWriter

# File to track manager-specific metadata within each checkpoint path
_MANAGER_METADATA_FILE = '.manager_metadata'


@dataclass
class _CheckpointMetadata:
# The step at which the checkpoint was taken
step: int

# The time at which the checkpoint was taken
ts: datetime


class CheckpointManager:
"""
@@ -53,22 +79,33 @@ class CheckpointManager:
https://github.com/google/orbax/blob/efc079c4e5b437782a80138913d322cb3ed365c7/checkpoint/orbax/checkpoint/checkpoint_manager.py
"""

# The base path to write checkpoints to. Each checkpoint taken by the manager
# will be written into a subdirectory of this path, identified by the
# checkpoint's step.
base_path: Union[str, os.PathLike]

# The interval to take checkpoints, in steps.
save_interval: int

# The maximum number of checkpoints to keep.
max_to_keep: int

def __init__(self,
path: str,
-save_period: int,
-max_to_keep: Optional[int] = -1,
+save_interval: int,
+max_to_keep: Optional[int] = 0,
async_queue_size: Optional[int] = 1):
"""
Create a checkpoint manager that reads and writes checkpoints into
the provided directory.
Args:
path: The base path for the CheckpointManager to write checkpoints into.
-save_period: The number of steps between saving checkpoints.
+save_interval: The number of steps between saving checkpoints.
max_to_keep: The maximum number of checkpoints to be tracked by the
CheckpointManager. When a new checkpoint is taken and the limit is exceeded,
the oldest tracked checkpoint (by creation time) is deleted.
-Default: -1, indicating no upper bound on the number of checkpoints.
+Default: 0, indicating no upper bound on the number of checkpoints.
async_queue_size: The size of the execution queue which processes async
checkpoints. This should be a small value to ensure training doesn't
get too far ahead of the last finished checkpoint, but increasing
@@ -77,14 +114,61 @@ def __init__(self,
Default: 1, which only allows a single async checkpoint to be
pending at a time.
"""
-raise NotImplementedError
assert dist.is_initialized(), "A process group is required."
assert save_interval > 0, "save_interval must be positive"
assert async_queue_size > 0, "async_queue_size must be positive"
assert max_to_keep >= 0, "max_to_keep must be non-negative"

self.base_path = path
self.save_interval = save_interval
self.max_to_keep = max_to_keep

self._tracked_chkpts = self._load_tracked_chkpts()

def _load_tracked_chkpts(self) -> Deque[_CheckpointMetadata]:
"""
Loads a list of all tracked checkpoints from the storage backend.
"""
all_chkpts = []
invalid_paths = []
fs, raw_path = url_to_fs(self.base_path)
for path in fs.ls(raw_path, detail=False):
try:
with fsspec.open(os.path.join(path, _MANAGER_METADATA_FILE), 'rb') as f:
all_chkpts.append(pickle.load(f))
except Exception:
invalid_paths.append(path)

if invalid_paths:
logging.warning(f'Ignoring invalid checkpoints: {invalid_paths}')
return deque(sorted(all_chkpts, key=lambda m: m.ts))

def _get_path(self, step: int) -> str:
return os.path.join(self.base_path, str(step))

def _delete_chkpt_at_step(self, step):
path = self._get_path(step)
fs, raw_path = url_to_fs(path)
if fs.exists(raw_path):
fs.rm(raw_path, recursive=True)

def _release_oldest_checkpoints(self):
"""
Delete the oldest checkpoints until the number of tracked checkpoints does not
exceed self.max_to_keep. This operation only executes on the rank 0 process.
"""
if dist.get_rank() == 0 and self.max_to_keep > 0:
while len(self._tracked_chkpts) > self.max_to_keep:
oldest_chkpt = self._tracked_chkpts.popleft()
self._delete_chkpt_at_step(oldest_chkpt.step)

def should_save(self, step: int) -> bool:
"""
Returns true if a checkpoint should be saved for the current step or if
a preemption has been detected.
"""
-raise NotImplementedError
# TODO(jonbolin): Support preemption notice for auto checkpointing
return step % self.save_interval == 0

def save(self,
step,
Expand All @@ -101,7 +185,22 @@ def save(self,
Returns:
True if a checkpoint was taken and False otherwise.
"""
-raise NotImplementedError
if self.should_save(step) or force:
path = self._get_path(step)
# Delete any existing checkpoint at the current step.
self._delete_chkpt_at_step(step)
dist_cp.save_state_dict(
state_dict=state_dict,
storage_writer=FsspecWriter(path),
planner=xc.SPMDSavePlanner(),
)
metadata = _CheckpointMetadata(step=step, ts=datetime.now())
with fsspec.open(os.path.join(path, _MANAGER_METADATA_FILE), 'wb') as f:
pickle.dump(metadata, f)
self._tracked_chkpts.append(metadata)
self._release_oldest_checkpoints()
return True
return False

def save_async(self,
step: int,
@@ -139,10 +238,17 @@ def restore(self, step: int, state_dict: STATE_DICT_TYPE) -> None:
state_dict: The state dict to restore the checkpoint into. Values are
updated in-place within the state_dict.
"""
-raise NotImplementedError
tracked_steps = set(x.step for x in self._tracked_chkpts)
assert step in tracked_steps, f'Cannot restore from untracked step {step}. Valid steps are: {tracked_steps}'
path = self._get_path(step)
dist_cp.load_state_dict(
state_dict=state_dict,
storage_reader=FsspecReader(path),
planner=xc.SPMDLoadPlanner(),
)

def all_steps(self) -> List[int]:
"""
List all steps tracked by the CheckpointManager.
"""
-raise NotImplementedError
+return sorted(x.step for x in self._tracked_chkpts)
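
Because all_steps() exposes the tracked steps, a resume-from-latest pattern follows directly from this API. The helper below is a hedged sketch, not part of the commit; the model argument and the trailing load_state_dict call are assumptions about how a caller might wire it up.

from typing import Optional

# Hypothetical helper built on top of CheckpointManager.
def restore_latest(chkpt_mgr, model) -> Optional[int]:
    tracked_steps = chkpt_mgr.all_steps()  # sorted list of tracked steps
    if not tracked_steps:
        return None                        # nothing to restore
    latest_step = max(tracked_steps)
    state_dict = model.state_dict()
    chkpt_mgr.restore(latest_step, state_dict)  # values are updated in-place
    model.load_state_dict(state_dict)
    return latest_step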
