Skip to content

Commit

Permalink
Update test_xla_sharding.py to take into account the number of devices (
Browse files Browse the repository at this point in the history
  • Loading branch information
bhavya01 authored Sep 12, 2024
1 parent 0535c88 commit fc03aef
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions test/spmd/test_xla_sharding.py
Original file line number Diff line number Diff line change
Expand Up @@ -1272,18 +1272,19 @@ def test_spmd_all_reduce(self):
f"all-reduce(f32[8,8]{{1,0}} %custom-call.2), channel_id=1, replica_groups={{{{{','.join([str(x) for x in self.device_ids])}}}}}, use_global_device_ids=true, to_apply=%AddComputation.3",
hlo)

expected_x = torch.ones(8, 8) * 4
expected_x = torch.ones(8, 8) * self.n_devices
self.assertTrue(torch.allclose(x.cpu(), expected_x))

@unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
"Only runs on TPUv4")
def test_spmd_all_reduce_scale(self):
xs.set_global_mesh(self._get_mesh((1, self.n_devices)))
x = torch.ones(8, 8).to(xm.xla_device())
scale = 0.25

# all reduce
x = xs.enable_manual_sharding(x, (None, None)).global_tensor
x = torch_xla._XLAC._xla_spmd_all_reduce(xm.REDUCE_SUM, x, 0.25,
x = torch_xla._XLAC._xla_spmd_all_reduce(xm.REDUCE_SUM, x, scale,
[self.device_ids])
x = xs.disable_manual_sharding(x, (None, None), x.shape).global_tensor

Expand All @@ -1292,7 +1293,7 @@ def test_spmd_all_reduce_scale(self):
f"all-reduce(f32[8,8]{{1,0}} %custom-call.2), channel_id=1, replica_groups={{{{{','.join([str(x) for x in self.device_ids])}}}}}, use_global_device_ids=true, to_apply=%AddComputation.3",
hlo)

expected_x = torch.ones(8, 8)
expected_x = torch.ones(8, 8) * int(self.n_devices * scale)
self.assertTrue(torch.allclose(x.cpu(), expected_x))

def test_get_1d_mesh(self):
Expand Down

0 comments on commit fc03aef

Please sign in to comment.