diff --git a/test/test_zero1.py b/test/test_zero1.py
index 17c46617973c..7fec8286f12c 100644
--- a/test/test_zero1.py
+++ b/test/test_zero1.py
@@ -10,10 +10,27 @@
 import unittest
 
 
+def _get_partial_states(s):
+  dp_size = xr.global_device_count()
+  dp_rank = xr.global_ordinal()
+
+  def convert_fn(tensors):
+    torch_xla._XLAC._xla_sync_multi(
+        tensors, devices=[], wait=True, sync_xla_data=True)
+    ret = []
+    for t in tensors:
+      ret.append(t.chunk(dp_size)[dp_rank])
+    return ret
+
+  def select_fn(v):
+    return type(v) == torch.Tensor and xm.is_xla_tensor(v)
+
+  return xm.ToXlaTensorArena(convert_fn, select_fn).transform(s)
+
+
 class XlaZeRO1Test(TestCase):
 
   @unittest.skipIf(xr.device_type() == 'TPU', "Crash on TPU")
-  @unittest.skipIf(xr.device_type() == 'CUDA', "Crash on CUDA")
   def test_zero1(self):
     device = xm.xla_device()
 
@@ -34,9 +51,11 @@ def test_zero1(self):
     opt1.step()
     opt2.step()
+
+    xm.mark_step()
 
     s1 = opt1.state_dict()
     s2 = opt2.state_dict()
-    self.assertEqual(s1['state'], s2['base_state'])
+    self.assertEqual(_get_partial_states(s1['state']), s2['base_state'])
 
     # deepcopy s1 to load later because pytorch optimizers do not guarantee the input
     # state_dict will not be modified. on the other hand, s2 has this guarantee.
@@ -44,16 +63,20 @@ def test_zero1(self):
 
     opt1.load_state_dict(s1)
     opt2.load_state_dict(s2)
-    self.assertEqual(opt1.state_dict()['state'],
-                     opt2.state_dict()['base_state'])
+    self.assertEqual(
+        _get_partial_states(opt1.state_dict()['state']),
+        opt2.state_dict()['base_state'])
 
     # step still runnable
     opt1.step()
     opt2.step()
+
+    xm.mark_step()
 
     opt1.load_state_dict(s1_clone)
     opt2.load_state_dict(s2)
-    self.assertEqual(opt1.state_dict()['state'],
-                     opt2.state_dict()['base_state'])
+    self.assertEqual(
+        _get_partial_states(opt1.state_dict()['state']),
+        opt2.state_dict()['base_state'])
 
     # step still runnable
     opt1.step()
diff --git a/torch_xla/distributed/zero_redundancy_optimizer.py b/torch_xla/distributed/zero_redundancy_optimizer.py
index 2299714271a4..00e626f7d975 100644
--- a/torch_xla/distributed/zero_redundancy_optimizer.py
+++ b/torch_xla/distributed/zero_redundancy_optimizer.py
@@ -10,6 +10,7 @@
 import torch_xla
 import torch_xla.core.xla_model as xm
+import torch_xla.runtime as xr
 
 
 class ZeroRedundancyOptimizer(Optimizer):
 
@@ -89,8 +90,8 @@ def __init__(
 
     super().__init__(params, defaults)
 
-    self.global_world_size = xm.xrt_world_size()
-    self.global_rank = xm.get_ordinal()
+    self.global_world_size = xr.global_device_count()
+    self.global_rank = xr.global_ordinal()
     self._sharding_groups = [list(range(self.global_world_size))
                             ] if sharding_groups is None else sharding_groups
     self._grad_norm_groups = grad_norm_groups
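
For context, a minimal CPU-only sketch of what the new `_get_partial_states` helper is doing in the test: `ZeroRedundancyOptimizer` keeps only the local rank's shard of each optimizer state tensor, so the unsharded optimizer's state has to be chunked along dimension 0 and indexed by rank before it can be compared against `base_state`. The `dp_size`, `dp_rank`, and `full_state` values below are illustrative stand-ins, not part of this patch.

```python
import torch

# Stand-ins for xr.global_device_count() and xr.global_ordinal().
dp_size, dp_rank = 4, 1

# Fake unsharded optimizer state, e.g. Adam's exp_avg for one parameter.
full_state = {0: {'exp_avg': torch.arange(8.0)}}


def get_partial_states(state):
  # Keep only this rank's chunk of every tensor, mirroring how the
  # sharded optimizer stores just its per-rank slice of the state.
  return {
      k: {
          name: v.chunk(dp_size)[dp_rank] if isinstance(v, torch.Tensor) else v
          for name, v in group.items()
      } for k, group in state.items()
  }


print(get_partial_states(full_state))  # {0: {'exp_avg': tensor([2., 3.])}}
```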