Skip to content

Commit

Permalink
Add TODO
Browse files Browse the repository at this point in the history
  • Loading branch information
jonb377 committed Oct 13, 2023
1 parent 702fbb5 commit fdecad3
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 2 deletions.
4 changes: 2 additions & 2 deletions test/spmd/test_xla_distributed_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,7 +415,7 @@ def test_manager_max_to_keep(self, tmpdir):

@run_with_tmpdir
def test_manager_async(self, tmpdir):
chkpt_mgr = CheckpointManager(tmpdir, save_period=10)
chkpt_mgr = CheckpointManager(tmpdir, save_interval=10)
state_dict = self._get_sharded_model().state_dict()

# Patch the manager's save method to block until this thread signals.
Expand Down Expand Up @@ -443,7 +443,7 @@ def patched_save(*args, **kwargs):

@run_with_tmpdir
def test_manager_async_step_tracking(self, tmpdir):
chkpt_mgr = CheckpointManager(tmpdir, save_period=10)
chkpt_mgr = CheckpointManager(tmpdir, save_interval=10)
state_dict = self._get_sharded_model().state_dict()

self.assertEqual(chkpt_mgr.all_steps(), [])
Expand Down
1 change: 1 addition & 0 deletions torch_xla/experimental/distributed_checkpoint/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ def __init__(self,
self._chkpt_thread.start()

# Create a new group if none is provided
# TODO(jonbolin): Verify subgroup on GPU backend
self.pg = process_group or dist.new_group()

def _load_tracked_chkpts(self) -> Deque[_CheckpointMetadata]:
Expand Down

0 comments on commit fdecad3

Please sign in to comment.