Support synchronous saving and loading in CheckpointManager
jonb377 committed Oct 10, 2023
1 parent 4524543 commit 4f88d9f
Showing 2 changed files with 147 additions and 7 deletions.
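For orientation, here is a minimal usage sketch of the synchronous API this commit adds, pieced together from the tests and implementation below. It is an illustration rather than code from the diff: the base path and the toy torch.nn.Linear are placeholders (the tests use an SPMD-sharded model on an XLA device), and a process group must already be initialized because the constructor asserts one exists.

import torch
import torch.distributed as dist
from torch_xla.experimental.distributed_checkpoint import CheckpointManager

# A one-process TCP group is enough for a local smoke test; CheckpointManager
# asserts that a process group has been initialized.
dist.init_process_group(
    init_method='tcp://127.0.0.1:8932', world_size=1, rank=0)

# Write a checkpoint every 10 steps into <base_path>/<step>, keeping only the
# two most recent step directories.
chkpt_mgr = CheckpointManager('/tmp/checkpoints', save_period=10, max_to_keep=2)

model = torch.nn.Linear(4, 4)  # placeholder; the tests use a sharded model
state_dict = model.state_dict()

for step in range(100):
  # ... training step ...
  chkpt_mgr.save(step, state_dict)  # False unless step % save_period == 0

# restore() updates the given state_dict in-place; load it back afterwards.
chkpt_mgr.restore(90, state_dict)
model.load_state_dict(state_dict)

dist.destroy_process_group()
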
87 changes: 86 additions & 1 deletion test/spmd/test_xla_distributed_checkpoint.py
@@ -1,3 +1,4 @@
import functools
import os
import sys
import tempfile
@@ -15,11 +16,23 @@
create_default_local_save_plan,
create_default_global_save_plan,
)
from torch_xla.experimental.distributed_checkpoint import SPMDLoadPlanner, SPMDSavePlanner
from torch_xla.experimental.distributed_checkpoint import SPMDLoadPlanner, SPMDSavePlanner, CheckpointManager
from torch_xla.experimental.distributed_checkpoint._helpers import (
_sharded_cpu_state_dict, _CpuShards, _is_sharded_tensor)


# Wrapper to manage a temporary directory for the wrapped test
def run_with_tmpdir(f):

@functools.wraps(f)
def run(*args, **kwargs):
assert 'tmpdir' not in kwargs
with tempfile.TemporaryDirectory() as tmpdir:
f(*args, **kwargs, tmpdir=tmpdir)

return run


class DistributedCheckpointTestBase(test_xla_sharding_base.XlaShardingTest):

@classmethod
@@ -319,6 +332,78 @@ def test_sharded_cpu_state_dict(self):
self.assertTrue(param.device == torch.device("cpu"))


class CheckpointManagerTest(DistributedCheckpointTestBase):

def setUp(self):
super().setUp()
# Initialize a minimal process group
dist.init_process_group(
init_method='tcp://127.1:8932', world_size=1, rank=0)

def tearDown(self):
super().tearDown()
# Destroy the CPU process group after the test
dist.destroy_process_group()

@run_with_tmpdir
def test_manager_checkpointing(self, tmpdir):
chkpt_mgr = CheckpointManager(tmpdir, save_period=10)
state_dict = self._get_sharded_model().state_dict()

# Take a checkpoint on step 0
self.assertTrue(chkpt_mgr.save(0, state_dict))

# Load the checkpoint into a new state_dict
new_state_dict = self._get_sharded_model().state_dict()
self.assertFalse(
any(
torch.allclose(v, new_state_dict[k])
for k, v in state_dict.items()))
chkpt_mgr.restore(0, new_state_dict)
self.assertTrue(
all(
torch.allclose(v, new_state_dict[k])
for k, v in state_dict.items()))

@run_with_tmpdir
def test_manager_step_tracking(self, tmpdir):
chkpt_mgr = CheckpointManager(tmpdir, save_period=10)
state_dict = self._get_sharded_model().state_dict()

# No steps are being tracked initially
self.assertEqual(chkpt_mgr.all_steps(), [])

# Steps not divisible by 10 should not be saved
for step in range(1, 10):
self.assertFalse(chkpt_mgr.save(step, state_dict))
self.assertEqual(chkpt_mgr.all_steps(), [])

# Steps divisible by 10 should be saved
saved = set()
for step in range(0, 100, 10):
self.assertTrue(chkpt_mgr.save(step, state_dict))
saved.add(step)
self.assertEqual(set(chkpt_mgr.all_steps()), saved)

@run_with_tmpdir
def test_manager_max_to_keep(self, tmpdir):
chkpt_mgr = CheckpointManager(tmpdir, save_period=10, max_to_keep=2)
state_dict = self._get_sharded_model().state_dict()

# No steps are being tracked initially
self.assertEqual(chkpt_mgr.all_steps(), [])

self.assertTrue(chkpt_mgr.save(10, state_dict))
self.assertEqual(set(chkpt_mgr.all_steps()), {10})

self.assertTrue(chkpt_mgr.save(20, state_dict))
self.assertEqual(set(chkpt_mgr.all_steps()), {10, 20})

# The oldest checkpoint should be erased
self.assertTrue(chkpt_mgr.save(30, state_dict))
self.assertEqual(set(chkpt_mgr.all_steps()), {30, 20})


if __name__ == '__main__':
test = unittest.main(exit=False)
sys.exit(0 if test.result.wasSuccessful() else 1)
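
One detail the tests above rely on, made explicit in the implementation below, is the on-disk layout: each saved step becomes a subdirectory of the base path named after the step, and the set of tracked steps is recovered purely from those directory names. A small standalone sketch of that layout, using placeholder directories that mirror the assumed end state of test_manager_max_to_keep:

import os
import tempfile

# Assumed end state of test_manager_max_to_keep: steps 10, 20 and 30 were
# saved with max_to_keep=2, so only the '20' and '30' directories survive and
# the tracked steps fall out of the directory names.
with tempfile.TemporaryDirectory() as base:
  for step in (20, 30):
    os.makedirs(os.path.join(base, str(step)))
  steps = sorted(int(name) for name in os.listdir(base))
  print(steps)  # [20, 30]
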
67 changes: 61 additions & 6 deletions torch_xla/experimental/distributed_checkpoint/manager.py
@@ -1,9 +1,17 @@
import os
import torch.distributed as dist
import torch.distributed.checkpoint as dist_cp
import torch_xla.runtime as xr
import torch_xla.experimental.distributed_checkpoint as xc

from typing import List, Optional
from fsspec.core import url_to_fs
from os.path import basename
from typing import List, Optional, Union
from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE

# TODO(jonbolin): Import path will change
from torch.distributed.checkpoint._fsspec_filesystem import FsspecReader, FsspecWriter


class CheckpointManager:
"""
@@ -53,6 +61,14 @@ class CheckpointManager:
https://github.com/google/orbax/blob/efc079c4e5b437782a80138913d322cb3ed365c7/checkpoint/orbax/checkpoint/checkpoint_manager.py
"""

# The base path to write checkpoints to. Each checkpoint taken by the manager
# will be written into a subdirectory of this path, identified by the
# checkpoint's step.
base_path: Union[str, os.PathLike]

# The period to take checkpoints, in steps.
save_period: int

def __init__(self,
path: str,
save_period: int,
@@ -77,14 +93,36 @@ def __init__(self,
Default: 1, which only allows a single async checkpoint to be
pending at a time.
"""
raise NotImplementedError
assert dist.is_initialized(), "A process group is required."

self.base_path = path
self.save_period = save_period
self.max_to_keep = max_to_keep
self.async_queue_size = async_queue_size
assert self.save_period > 0, "save_period must be positive"
assert self.async_queue_size > 0, "async_queue_size must be positive"
assert self.max_to_keep != 0, "max_to_keep must be non-zero"

def _get_path(self, step: int) -> str:
return os.path.join(self.base_path, str(step))

def _release_oldest_checkpoints(self):
if self.max_to_keep > 0:
tracked_steps = sorted(self.all_steps(), reverse=True)
while len(tracked_steps) > self.max_to_keep:
# Delete the oldest checkpoint step to free up space for the new one.
oldest_step = tracked_steps.pop()
path = self._get_path(oldest_step)
fs, raw_path = url_to_fs(path)
fs.rm(raw_path, recursive=True)

def should_save(self, step: int) -> bool:
"""
Returns true if a checkpoint should be saved for the current step or if
a preemption has been detected.
"""
raise NotImplementedError
# TODO(jonbolin): Support preemption notice for auto checkpointing
return step % self.save_period == 0

def save(self,
step,
Expand All @@ -101,7 +139,16 @@ def save(self,
Returns:
True if a checkpoint was taken and False otherwise.
"""
raise NotImplementedError
if self.should_save(step):
path = self._get_path(step)
dist_cp.save_state_dict(
state_dict=state_dict,
storage_writer=FsspecWriter(path),
planner=xc.SPMDSavePlanner(),
)
self._release_oldest_checkpoints()
return True
return False

def save_async(self,
step: int,
Expand Down Expand Up @@ -139,10 +186,18 @@ def restore(self, step: int, state_dict: STATE_DICT_TYPE) -> None:
state_dict: The state dict to restore the checkpoint into. Values are
updated in-place within the state_dict.
"""
raise NotImplementedError
path = self._get_path(step)
dist_cp.load_state_dict(
state_dict=state_dict,
storage_reader=FsspecReader(path),
planner=xc.SPMDLoadPlanner(),
)

def all_steps(self) -> List[int]:
"""
List all steps tracked by the CheckpointManager.
"""
raise NotImplementedError
fs, raw_path = url_to_fs(self.base_path)
all_paths = fs.ls(raw_path, detail=False)
all_steps = map(basename, all_paths)
return list(map(int, all_steps))
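
Because both the storage reader/writer and the pruning path go through fsspec, the base path can be a local directory or any URL fsspec understands. A small illustration of the url_to_fs resolution used above; the gs:// example is an assumption and requires the matching fsspec backend (for example gcsfs) to be installed.

from fsspec.core import url_to_fs

# A plain local path resolves to a LocalFileSystem; raw_path is the
# filesystem-local form that fs.ls() and fs.rm() operate on.
fs, raw_path = url_to_fs('/tmp/checkpoints')
print(type(fs).__name__, raw_path)

# A remote URL resolves to the matching backend, e.g. gcsfs for gs://
# (assumption: that backend package is installed).
# fs, raw_path = url_to_fs('gs://my-bucket/checkpoints')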
