diff --git a/test/spmd/test_fsdp_v2.py b/test/spmd/test_fsdp_v2.py index 08b588c9edc..1429e377b18 100644 --- a/test/spmd/test_fsdp_v2.py +++ b/test/spmd/test_fsdp_v2.py @@ -148,6 +148,8 @@ def test_fsdp_v2_multi_slice(self): # Make sure all weights are sharded. annotation = '{devices=[2,1,2]0,2,1,3 last_tile_dim_replicate}' + if self.n_devices == 8: + annotation = '{devices=[4,1,2]0,4,1,5,2,6,3,7 last_tile_dim_replicate}' self.assertEqual(annotation, torch_xla._XLAC._get_xla_sharding_spec(model.fc1.weight)) self.assertEqual(annotation, @@ -158,6 +160,8 @@ def test_fsdp_v2_multi_slice(self): output = model(x) # Make sure output are sharded. annotation = '{devices=[4,1]0,1,2,3}' + if self.n_devices == 8: + annotation = '{devices=[8,1]0,1,2,3,4,5,6,7}' self.assertEqual(annotation, torch_xla._XLAC._get_xla_sharding_spec(x)) self.assertEqual(annotation, torch_xla._XLAC._get_xla_sharding_spec(output))