Skip to content

Commit

Permalink
Add missing import for XLA init_method (#6809)
Browse files Browse the repository at this point in the history
  • Loading branch information
jonb377 authored Mar 22, 2024
1 parent b0ceb2b commit 34d17c2
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion test/spmd/test_xla_distributed_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,8 @@ def test_resharding_different_device_mesh(self):
def test_multihost_checkpoint(self):
torch.manual_seed(42)

# Initialize the default CPU process group from the environment.
# Initialize the default CPU process group.
import torch_xla.distributed.xla_backend
dist.init_process_group(backend='gloo', init_method='xla://')

model1 = self._get_sharded_model(mesh_shape=(1, self.n_devices))
Expand Down

0 comments on commit 34d17c2

Please sign in to comment.