[ZeRO-1] Sync features from r2.1_aws_neuron branch (#7132)
hgt312 authored Jun 10, 2024
1 parent d0fb59e commit 551a76c
Showing 4 changed files with 198 additions and 44 deletions.
2 changes: 1 addition & 1 deletion test/run_tests.sh

@@ -184,7 +184,6 @@ function run_xla_op_tests1 {
   run_test "$CDIR/test_python_ops.py"
   run_test "$CDIR/test_ops.py"
   run_test "$CDIR/test_metrics.py"
-  run_test "$CDIR/test_zero1.py"
   run_test "$CDIR/dynamo/test_dynamo_integrations_util.py"
   run_test "$CDIR/dynamo/test_dynamo_aliasing.py"
   run_test "$CDIR/dynamo/test_dynamo.py"
@@ -296,6 +295,7 @@ function run_mp_op_tests {
   run_test "$CDIR/test_mp_collective_permute.py"
   run_test "$CDIR/test_mp_all_gather.py"
   run_test "$CDIR/test_mp_reduce_scatter.py"
+  run_test "$CDIR/test_zero1.py"
   run_test "$CDIR/test_mp_distributed_mm.py"
   run_test "$CDIR/test_mp_save.py"
   run_test "$CDIR/test_mp_mesh_reduce.py"
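Note: test_zero1.py moves out of the single-process run_xla_op_tests1 group and into run_mp_op_tests because the rewritten test (below) launches one process per device through xmp.spawn. A minimal sketch of that launch pattern, mirroring the _mp_fn hook added in test_zero1.py below:

import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp


def _mp_fn(index):
  # Runs once in each spawned process; every process gets its own ordinal
  # and XLA device, so collectives see a real multi-replica world.
  print(index, xm.xla_device())


if __name__ == '__main__':
  xmp.spawn(_mp_fn, args=())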
76 changes: 58 additions & 18 deletions test/test_zero1.py

@@ -2,64 +2,104 @@
 import torch.nn as nn
 import torch_xla
 import torch_xla.core.xla_model as xm
+import torch_xla.distributed.xla_multiprocessing as xmp
 from torch_xla.distributed.zero_redundancy_optimizer import ZeroRedundancyOptimizer
 from torch_xla import runtime as xr
-from torch.testing._internal.common_utils import TestCase
 from copy import deepcopy

 import sys
 import unittest

+import test_utils

-class XlaZeRO1Test(TestCase):

+def _get_partial_states(s):
+  dp_size = xr.global_device_count()
+  dp_rank = xr.global_ordinal()
+
+  def convert_fn(tensors):
+    torch_xla._XLAC._xla_sync_multi(
+        tensors, devices=[], wait=True, sync_xla_data=True)
+    ret = []
+    for t in tensors:
+      ret.append(t.chunk(dp_size)[dp_rank].detach().cpu())
+    return ret
+
+  def select_fn(v):
+    return type(v) == torch.Tensor and xm.is_xla_tensor(v)
+
+  return xm.ToXlaTensorArena(convert_fn, select_fn).transform(s)
+
+
+class XlaZeRO1Test(test_utils.XlaTestCase):

   @unittest.skipIf(xr.device_type() == 'TPU', "Crash on TPU")
+  @unittest.skipIf(xr.device_type() == 'CUDA', "Crash on CUDA")
   def test_zero1(self):
     device = xm.xla_device()

-    model = nn.Linear(8, 8)
-    x = torch.ones((8, 8))
+    model = nn.Linear(32, 32)
+    x = torch.ones((32, 32))
     x.requires_grad = True
     model = model.to(device)
     x = x.to(device)
     y = model(x).sum()
     y.backward()
     xm.mark_step()

     opt1 = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
-    opt1.step()
-    xm.mark_step()
-
     opt2 = ZeroRedundancyOptimizer(
         model.parameters(),
         torch.optim.SGD,
         lr=0.01,
         momentum=0.9,
         grad_clipping=False)

+    opt1.step()
     opt2.step()
     xm.mark_step()

     s1 = opt1.state_dict()
     s2 = opt2.state_dict()
-    self.assertEqual(s1['state'], s2['base_state'])
+    self.assertEqual(_get_partial_states(s1['state']), s2['base_state'])

     # deepcopy s1 to load later because pytorch optimizers do not guarantee the input
     # state_dict will not be modified. on the other hand, s2 has this guarantee.
-    s1_clone = deepcopy(s1)
+    s1_clone = deepcopy(xm._maybe_convert_to_cpu(s1))
+    s2_clone = deepcopy(xm._maybe_convert_to_cpu(s2))

     opt1.load_state_dict(s1)
     opt2.load_state_dict(s2)
-    self.assertEqual(opt1.state_dict()['state'],
-                     opt2.state_dict()['base_state'])
+    self.assertEqual(
+        _get_partial_states(opt1.state_dict()['state']),
+        opt2.state_dict()['base_state'])

     # step still runnable
     opt1.step()
     opt2.step()
     xm.mark_step()

     opt1.load_state_dict(s1_clone)
-    opt2.load_state_dict(s2)
-    self.assertEqual(opt1.state_dict()['state'],
-                     opt2.state_dict()['base_state'])
+    opt2.load_state_dict(s2_clone)
+    xm.mark_step()
+    self.assertEqual(
+        _get_partial_states(opt1.state_dict()['state']),
+        opt2.state_dict()['base_state'])

     # step still runnable
     opt1.step()
     opt2.step()
+    xm.mark_step()


+def _mp_fn(index):
+  device = xm.xla_device()
+  if xm.xla_device_hw(device) in ('TPU', 'CUDA'):
+    test = unittest.main(exit=False)
+    sys.exit(0 if test.result.wasSuccessful() else 1)
+  else:
+    print(
+        'Default device {} is not a TPU or CUDA device'.format(device),
+        file=sys.stderr)
+
+
 if __name__ == '__main__':
-  test = unittest.main()
-  sys.exit(0 if test.result.wasSuccessful() else 1)
+  xmp.spawn(_mp_fn, args=())
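Note: the new module-level _get_partial_states helper reduces a full optimizer state_dict to the slice the current rank owns under ZeRO-1, so the test can compare plain SGD state against ZeroRedundancyOptimizer's sharded base_state. A standalone illustration of the shard-selection rule on plain CPU tensors (dp_size/dp_rank are hypothetical stand-ins for xr.global_device_count()/xr.global_ordinal()):

import torch

dp_size, dp_rank = 4, 1  # hypothetical world size and rank

full = torch.arange(32.0)             # a full, unsharded state tensor
shard = full.chunk(dp_size)[dp_rank]  # rank 1 owns elements 8..15

assert shard.equal(torch.arange(8.0, 16.0))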
2 changes: 1 addition & 1 deletion torch_xla/core/xla_model.py

@@ -967,7 +967,7 @@ def reduce_scatter_bucketized(reduce_type,
     see reduce_scatter for reduce_type, scale, scatter_dim, shard_count, groups, pin_layout
     input_list: List of input tensors
     output: Optional list of output torch.Tensor
-    bucket_cap_mb: Number of MegaBytes of the tensor bucket to fill before doing all-gather.
+    bucket_cap_mb: Number of MegaBytes of the tensor bucket to fill before doing reduce-scatter.

   Returns:
     A list of `torch.Tensors` with all the values reduced across replicas. Each process
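Note: the corrected docstring line reflects that bucket_cap_mb bounds the bucket that is flushed with a reduce-scatter, not an all-gather. A usage sketch, assuming the keyword names listed in this docstring; the exact signature and defaults are assumptions, not verified against the full source:

import torch
import torch_xla.core.xla_model as xm
from torch_xla import runtime as xr

shard_count = xr.global_device_count()
grads = [torch.randn(1024, 1024, device=xm.xla_device()) for _ in range(4)]

# Input tensors are coalesced into buckets of up to bucket_cap_mb megabytes;
# each full bucket is flushed with a single reduce-scatter instead of one
# collective per tensor.
sharded = xm.reduce_scatter_bucketized(
    xm.REDUCE_SUM,
    grads,
    scale=1.0 / shard_count,
    scatter_dim=0,
    shard_count=shard_count,
    pin_layout=False,
    bucket_cap_mb=50)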