diff --git a/test/spmd/test_xla_distributed_checkpoint.py b/test/spmd/test_xla_distributed_checkpoint.py
index 044cc968c0b..37c0224a0f4 100644
--- a/test/spmd/test_xla_distributed_checkpoint.py
+++ b/test/spmd/test_xla_distributed_checkpoint.py
@@ -415,7 +415,7 @@ def test_manager_max_to_keep(self, tmpdir):
 
   @run_with_tmpdir
   def test_manager_async(self, tmpdir):
-    chkpt_mgr = CheckpointManager(tmpdir, save_period=10)
+    chkpt_mgr = CheckpointManager(tmpdir, save_interval=10)
     state_dict = self._get_sharded_model().state_dict()
 
     # Patch the manager's save method to block until this thread signals.
@@ -443,7 +443,7 @@ def patched_save(*args, **kwargs):
 
   @run_with_tmpdir
   def test_manager_async_step_tracking(self, tmpdir):
-    chkpt_mgr = CheckpointManager(tmpdir, save_period=10)
+    chkpt_mgr = CheckpointManager(tmpdir, save_interval=10)
    state_dict = self._get_sharded_model().state_dict()
 
     self.assertEqual(chkpt_mgr.all_steps(), [])
diff --git a/torch_xla/experimental/distributed_checkpoint/manager.py b/torch_xla/experimental/distributed_checkpoint/manager.py
index 476bae5b882..9e5cde711b8 100644
--- a/torch_xla/experimental/distributed_checkpoint/manager.py
+++ b/torch_xla/experimental/distributed_checkpoint/manager.py
@@ -140,6 +140,7 @@ def __init__(self,
       self._chkpt_thread.start()
 
     # Create a new group if none is provided
+    # TODO(jonbolin): Verify subgroup on GPU backend
     self.pg = process_group or dist.new_group()
 
   def _load_tracked_chkpts(self) -> Deque[_CheckpointMetadata]:
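
For reference, a minimal usage sketch of the renamed parameter. It assumes the CheckpointManager exported from torch_xla.experimental.distributed_checkpoint as exercised by the tests above; the model setup, checkpoint path, and step loop are hypothetical, not part of this patch.

import torch
from torch_xla.experimental.distributed_checkpoint import CheckpointManager

# Hypothetical toy model; the tests above use a sharded model instead.
model = torch.nn.Linear(16, 16)
state_dict = {'model': model.state_dict()}

# `save_interval` (formerly `save_period`) is the number of steps
# between eligible checkpoints.
chkpt_mgr = CheckpointManager('/tmp/chkpts', save_interval=10)

for step in range(30):
  # ... one training step ...
  if chkpt_mgr.should_save(step):
    chkpt_mgr.save(step, state_dict)

# Steps with a tracked checkpoint, e.g. [0, 10, 20].
print(chkpt_mgr.all_steps())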