Commit

init
hgt312 committed May 3, 2024
1 parent e0655ee commit 0a48474
Showing 2 changed files with 32 additions and 8 deletions.
35 changes: 29 additions & 6 deletions test/test_zero1.py
@@ -10,10 +10,27 @@
import unittest


+def _get_partial_states(s):
+  dp_size = xr.global_device_count()
+  dp_rank = xr.global_ordinal()
+
+  def convert_fn(tensors):
+    torch_xla._XLAC._xla_sync_multi(
+        tensors, devices=[], wait=True, sync_xla_data=True)
+    ret = []
+    for t in tensors:
+      ret.append(t.chunk(dp_size)[dp_rank])
+    return ret
+
+  def select_fn(v):
+    return type(v) == torch.Tensor and xm.is_xla_tensor(v)
+
+  return xm.ToXlaTensorArena(convert_fn, select_fn).transform(s)


class XlaZeRO1Test(TestCase):

  @unittest.skipIf(xr.device_type() == 'TPU', "Crash on TPU")
  @unittest.skipIf(xr.device_type() == 'CUDA', "Crash on CUDA")
  def test_zero1(self):
    device = xm.xla_device()

@@ -34,26 +51,32 @@ def test_zero1(self):

    opt1.step()
    opt2.step()

    xm.mark_step()
    s1 = opt1.state_dict()
    s2 = opt2.state_dict()
-    self.assertEqual(s1['state'], s2['base_state'])
+    self.assertEqual(_get_partial_states(s1['state']), s2['base_state'])

    # deepcopy s1 to load later because pytorch optimizers do not guarantee the input
    # state_dict will not be modified. on the other hand, s2 has this guarantee.
    s1_clone = deepcopy(s1)

    opt1.load_state_dict(s1)
    opt2.load_state_dict(s2)
-    self.assertEqual(opt1.state_dict()['state'],
-                     opt2.state_dict()['base_state'])
+    self.assertEqual(
+        _get_partial_states(opt1.state_dict()['state']),
+        opt2.state_dict()['base_state'])

    # step still runnable
    opt1.step()
    opt2.step()

    xm.mark_step()
    opt1.load_state_dict(s1_clone)
    opt2.load_state_dict(s2)
-    self.assertEqual(opt1.state_dict()['state'],
-                     opt2.state_dict()['base_state'])
+    self.assertEqual(
+        _get_partial_states(opt1.state_dict()['state']),
+        opt2.state_dict()['base_state'])

    # step still runnable
    opt1.step()
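
The new _get_partial_states helper above first forces a sync of the XLA tensors in a state dict, then keeps only the chunk owned by the current data-parallel rank. That lets the full optimizer state held by the unsharded reference optimizer (opt1) be compared against the ZeRO-1 optimizer's sharded base_state. A minimal CPU-only sketch of the chunking step (plain PyTorch, no XLA; the world size, rank, and tensor values are made up for illustration):

import torch

dp_size, dp_rank = 4, 1                     # assumed: 4 data-parallel ranks, this process is rank 1
full_state = torch.arange(8.0)              # stands in for one full, unsharded optimizer-state tensor
shard = full_state.chunk(dp_size)[dp_rank]  # the slice this rank would own
print(shard)                                # tensor([2., 3.])

With that slicing in place, each rank asserts only against the shard it actually stores, which is why the direct state-dict comparisons were replaced.
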
5 changes: 3 additions & 2 deletions torch_xla/distributed/zero_redundancy_optimizer.py
@@ -10,6 +10,7 @@

import torch_xla
import torch_xla.core.xla_model as xm
+import torch_xla.runtime as xr


class ZeroRedundancyOptimizer(Optimizer):
@@ -89,8 +90,8 @@ def __init__(

    super().__init__(params, defaults)

-    self.global_world_size = xm.xrt_world_size()
-    self.global_rank = xm.get_ordinal()
+    self.global_world_size = xr.global_device_count()
+    self.global_rank = xr.global_ordinal()
    self._sharding_groups = [list(range(self.global_world_size))
                            ] if sharding_groups is None else sharding_groups
    self._grad_norm_groups = grad_norm_groups
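
In zero_redundancy_optimizer.py the change is an API swap: the world size and rank that seed the default sharding group now come from torch_xla.runtime instead of the older xm.xrt_world_size() / xm.get_ordinal() helpers. For context, a rough usage sketch of the wrapper on an XLA device (the model, batch, and learning rate are placeholders, not taken from this commit; the constructor follows the upstream signature ZeroRedundancyOptimizer(params, optimizer_class, **defaults)):

import torch
import torch_xla.core.xla_model as xm
from torch_xla.distributed.zero_redundancy_optimizer import ZeroRedundancyOptimizer

device = xm.xla_device()
model = torch.nn.Linear(8, 8).to(device)      # placeholder model
# Wraps a base optimizer class; this rank keeps only its shard of the optimizer state.
opt = ZeroRedundancyOptimizer(model.parameters(), torch.optim.SGD, lr=0.01)

loss = model(torch.randn(4, 8, device=device)).sum()   # placeholder forward pass
loss.backward()
opt.step()       # reduce gradients, update the local shard, all-gather the updated parameters
xm.mark_step()
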
