diff --git a/test/spmd/test_xla_distributed_checkpoint.py b/test/spmd/test_xla_distributed_checkpoint.py index 3952e2d01a3..b3254dc44fe 100644 --- a/test/spmd/test_xla_distributed_checkpoint.py +++ b/test/spmd/test_xla_distributed_checkpoint.py @@ -4,6 +4,7 @@ import tempfile import unittest import test_xla_sharding_base +import threading import torch import torch.distributed as dist @@ -413,7 +414,35 @@ def test_manager_max_to_keep(self, tmpdir): self.assertEqual(set(chkpt_mgr.all_steps()), {20, 10}) @run_with_tmpdir - def test_manager_async_checkpoint(self, tmpdir): + def test_manager_async(self, tmpdir): + chkpt_mgr = CheckpointManager(tmpdir, save_period=10) + state_dict = self._get_sharded_model().state_dict() + + # Patch the manager's save method to block until this thread signals. + cond = threading.Condition() + old_save = chkpt_mgr.save + + def patched_save(*args, **kwargs): + cond.wait() + old_save(*args, **kwargs) + + with unittest.mock.patch.object(chkpt_mgr, 'save', patched_save): + chkpt_mgr.save_async(10, state_dict) + + # No new steps should be tracked immediately after calling save_async + self.assertEqual(chkpt_mgr.all_steps(), []) + + # Trigger the actual checkpoint in the background thread and wait for + # completion. + with cond: + cond.notify() + chkpt_mgr.join() + + # The manager should track all steps which were asynchronously saved. + self.assertEqual(set(chkpt_mgr.all_steps()), {10}) + + @run_with_tmpdir + def test_manager_async_step_tracking(self, tmpdir): chkpt_mgr = CheckpointManager(tmpdir, save_period=10) state_dict = self._get_sharded_model().state_dict() @@ -430,12 +459,10 @@ def test_manager_async_checkpoint(self, tmpdir): self.assertTrue(chkpt_mgr.save_async(step, state_dict)) saved.add(step) - # Delete the checkpoint manager to block this thread until all pending - # async checkpoints are complete. - del chkpt_mgr + # Join to allow pending async checkpoints to complete + chkpt_mgr.join() # The manager should track all steps which were asynchronously saved. - chkpt_mgr = CheckpointManager(tmpdir, save_period=10) self.assertEqual(set(chkpt_mgr.all_steps()), saved) # Load a checkpoint into a new state_dict diff --git a/torch_xla/experimental/distributed_checkpoint/manager.py b/torch_xla/experimental/distributed_checkpoint/manager.py index 0ad0f8901aa..572ed2be4af 100644 --- a/torch_xla/experimental/distributed_checkpoint/manager.py +++ b/torch_xla/experimental/distributed_checkpoint/manager.py @@ -98,7 +98,8 @@ def __init__(self, path: str, save_period: int, max_to_keep: Optional[int] = 0, - async_queue_size: Optional[int] = 1): + async_queue_size: Optional[int] = 1, + process_group: dist.ProcessGroup = None): """ Create a checkpoint manager that reads and writes checkpoints into the provided directory. @@ -117,6 +118,9 @@ def __init__(self, network issues which slow down the active checkpoint. Default: 1, which only allows a single async checkpoint to be pending at a time. + process_group: The process group to use when coordinating the checkpoint. + Default: None, in which case a subgroup of the default process + group will be created. """ assert dist.is_initialized(), "A process group is required." assert save_period > 0, "save_period must be positive" @@ -128,13 +132,15 @@ def __init__(self, self.max_to_keep = max_to_keep self._tracked_chkpts = self._load_tracked_chkpts() - self._async_queue = queue.Queue(maxsize=async_queue_size) - self._chkpt_thread = threading.Thread(target=self._async_worker, daemon=True) + self._alive = threading.Event() + self._alive.set() + self._chkpt_thread = threading.Thread( + target=self._async_worker, daemon=True) self._chkpt_thread.start() - # Create a CPU process group to coordinate the checkpoint. - self.pg = dist.new_group(backend='gloo') + # Create a new group if none is provided + self.pg = process_group or dist.new_group() def _load_tracked_chkpts(self) -> Deque[_CheckpointMetadata]: """ @@ -155,14 +161,19 @@ def _load_tracked_chkpts(self) -> Deque[_CheckpointMetadata]: return deque(sorted(all_chkpts, key=lambda m: m.ts)) def __del__(self): - # Ensure pending checkpoints are finished - self._async_queue.join() + self._alive.clear() + # Send a sentinel value to tell the worker to exit, and wait for pending + # checkpoints to complete. + self._async_queue.put((None, None)) + self._chkpt_thread.join() def _async_worker(self): - while True: + while self._alive.is_set(): try: - step, state_dict = self._async_queue.get() - self.save(step, state_dict, force=True) + item = self._async_queue.get() + if item: + step, state_dict = item + self.save(step, state_dict, force=True) except: traceback.print_exc() finally: @@ -284,3 +295,7 @@ def all_steps(self) -> List[int]: List all steps tracked by the CheckpointManager. """ return sorted(x.step for x in self._tracked_chkpts) + + def join(self): + """ Wait for all pending async checkpoints to complete. """ + self._async_queue.join()