diff --git a/test/spmd/test_xla_distributed_checkpoint.py b/test/spmd/test_xla_distributed_checkpoint.py index 276571e5979..f51d7dd95f5 100644 --- a/test/spmd/test_xla_distributed_checkpoint.py +++ b/test/spmd/test_xla_distributed_checkpoint.py @@ -1,3 +1,4 @@ +import functools import os import sys import tempfile @@ -15,11 +16,23 @@ create_default_local_save_plan, create_default_global_save_plan, ) -from torch_xla.experimental.distributed_checkpoint import SPMDLoadPlanner, SPMDSavePlanner +from torch_xla.experimental.distributed_checkpoint import SPMDLoadPlanner, SPMDSavePlanner, CheckpointManager from torch_xla.experimental.distributed_checkpoint._helpers import ( _sharded_cpu_state_dict, _CpuShards, _is_sharded_tensor) +# Wrapper to manage a temporary directory for the wrapped test +def run_with_tmpdir(f): + + @functools.wraps(f) + def run(*args, **kwargs): + assert 'tmpdir' not in kwargs + with tempfile.TemporaryDirectory() as tmpdir: + f(*args, **kwargs, tmpdir=tmpdir) + + return run + + class DistributedCheckpointTestBase(test_xla_sharding_base.XlaShardingTest): @classmethod @@ -319,6 +332,78 @@ def test_sharded_cpu_state_dict(self): self.assertTrue(param.device == torch.device("cpu")) +class CheckpointManagerTest(DistributedCheckpointTestBase): + + def setUp(self): + super().setUp() + # Initialize the a minimal process group + dist.init_process_group( + init_method='tcp://127.1:8932', world_size=1, rank=0) + + def tearDown(self): + super().tearDown() + # Destroy the CPU process group after the test + dist.destroy_process_group() + + @run_with_tmpdir + def test_manager_checkpointing(self, tmpdir): + chkpt_mgr = CheckpointManager(tmpdir, save_period=10) + state_dict = self._get_sharded_model().state_dict() + + # Take a checkpoint on step 0 + self.assertTrue(chkpt_mgr.save(0, state_dict)) + + # Load the checkpoint into a new state_dict + new_state_dict = self._get_sharded_model().state_dict() + self.assertFalse( + any( + torch.allclose(v, new_state_dict[k]) + for k, v in state_dict.items())) + chkpt_mgr.restore(0, new_state_dict) + self.assertTrue( + all( + torch.allclose(v, new_state_dict[k]) + for k, v in state_dict.items())) + + @run_with_tmpdir + def test_manager_step_tracking(self, tmpdir): + chkpt_mgr = CheckpointManager(tmpdir, save_period=10) + state_dict = self._get_sharded_model().state_dict() + + # No steps are being tracked initially + self.assertEqual(chkpt_mgr.all_steps(), []) + + # Steps not divisible by 10 should not be saved + for step in range(1, 10): + self.assertFalse(chkpt_mgr.save(step, state_dict)) + self.assertEqual(chkpt_mgr.all_steps(), []) + + # Steps divisible by 10 should be saved + saved = set() + for step in range(0, 100, 10): + self.assertTrue(chkpt_mgr.save(step, state_dict)) + saved.add(step) + self.assertEqual(set(chkpt_mgr.all_steps()), saved) + + @run_with_tmpdir + def test_manager_max_to_keep(self, tmpdir): + chkpt_mgr = CheckpointManager(tmpdir, save_period=10, max_to_keep=2) + state_dict = self._get_sharded_model().state_dict() + + # No steps are being tracked initially + self.assertEqual(chkpt_mgr.all_steps(), []) + + self.assertTrue(chkpt_mgr.save(10, state_dict)) + self.assertEqual(set(chkpt_mgr.all_steps()), {10}) + + self.assertTrue(chkpt_mgr.save(20, state_dict)) + self.assertEqual(set(chkpt_mgr.all_steps()), {10, 20}) + + # The oldest checkpoint should be erased + self.assertTrue(chkpt_mgr.save(30, state_dict)) + self.assertEqual(set(chkpt_mgr.all_steps()), {30, 20}) + + if __name__ == '__main__': test = unittest.main() sys.exit(0 if test.result.wasSuccessful() else 1) diff --git a/torch_xla/experimental/distributed_checkpoint/manager.py b/torch_xla/experimental/distributed_checkpoint/manager.py index cd36cbe1eb6..0232f2b8f95 100644 --- a/torch_xla/experimental/distributed_checkpoint/manager.py +++ b/torch_xla/experimental/distributed_checkpoint/manager.py @@ -1,9 +1,17 @@ +import os +import torch.distributed as dist import torch.distributed.checkpoint as dist_cp +import torch_xla.runtime as xr import torch_xla.experimental.distributed_checkpoint as xc -from typing import List, Optional +from fsspec.core import url_to_fs +from os.path import basename +from typing import List, Optional, Union from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE +# TODO(jonbolin): Import path will change +from torch.distributed.checkpoint._fsspec_filesystem import FsspecReader, FsspecWriter + class CheckpointManager: """ @@ -53,6 +61,14 @@ class CheckpointManager: https://github.com/google/orbax/blob/efc079c4e5b437782a80138913d322cb3ed365c7/checkpoint/orbax/checkpoint/checkpoint_manager.py """ + # The base path to write checkpoints to. Each checkpoint taken by the manager + # will be written into a subdirectory of this path, identified by the + # checkpoint's step. + base_path: Union[str, os.PathLike] + + # The period to take checkpoints, in steps. + save_period: int + def __init__(self, path: str, save_period: int, @@ -77,14 +93,36 @@ def __init__(self, Default: 1, which only allows a single async checkpoint to be pending at a time. """ - raise NotImplementedError + assert dist.is_initialized(), "A process group is required." + + self.base_path = path + self.save_period = save_period + self.max_to_keep = max_to_keep + self.async_queue_size = async_queue_size + assert self.save_period > 0, "save_period must be positive" + assert self.async_queue_size > 0, "async_queue_size must be positive" + assert self.max_to_keep != 0, "max_to_keep must be non-zero" + + def _get_path(self, step: int) -> str: + return os.path.join(self.base_path, str(step)) + + def _release_oldest_checkpoints(self): + if self.max_to_keep > 0: + tracked_steps = sorted(self.all_steps(), reverse=True) + while len(tracked_steps) > self.max_to_keep: + # Delete the oldest checkpoint step to free up space for the new one. + oldest_step = tracked_steps.pop() + path = self._get_path(oldest_step) + fs, raw_path = url_to_fs(path) + fs.rm(raw_path, recursive=True) def should_save(self, step: int) -> bool: """ Returns true if a checkpoint should be saved for the current step or if a preemption has been detected. """ - raise NotImplementedError + # TODO(jonbolin): Support preemption notice for auto checkpointing + return step % self.save_period == 0 def save(self, step, @@ -101,7 +139,16 @@ def save(self, Returns: True if a checkpoint was taken and False otherwise. """ - raise NotImplementedError + if self.should_save(step): + path = self._get_path(step) + dist_cp.save_state_dict( + state_dict=state_dict, + storage_writer=FsspecWriter(path), + planner=xc.SPMDSavePlanner(), + ) + self._release_oldest_checkpoints() + return True + return False def save_async(self, step: int, @@ -139,10 +186,18 @@ def restore(self, step: int, state_dict: STATE_DICT_TYPE) -> None: state_dict: The state dict to restore the checkpoint into. Values are updated in-place within the state_dict. """ - raise NotImplementedError + path = self._get_path(step) + dist_cp.load_state_dict( + state_dict=state_dict, + storage_reader=FsspecReader(path), + planner=xc.SPMDLoadPlanner(), + ) def all_steps(self) -> List[int]: """ List all steps tracked by the CheckpointManager. """ - raise NotImplementedError + fs, raw_path = url_to_fs(self.base_path) + all_paths = fs.ls(raw_path, detail=False) + all_steps = map(basename, all_paths) + return list(map(int, all_steps))