[FSDPv2] Fix test_fsdp_v2_multi_slice (#7055)
Summary:
Fix test_fsdp_v2_multi_slice for TPU v2.

Test Plan:
CI.
alanwaketan authored May 13, 2024
1 parent 6f0b61e commit 4a1588c
Showing 1 changed file with 4 additions and 0 deletions.

test/spmd/test_fsdp_v2.py
@@ -148,6 +148,8 @@ def test_fsdp_v2_multi_slice(self):

     # Make sure all weights are sharded.
     annotation = '{devices=[2,1,2]0,2,1,3 last_tile_dim_replicate}'
+    if self.n_devices == 8:
+      annotation = '{devices=[4,1,2]0,4,1,5,2,6,3,7 last_tile_dim_replicate}'
     self.assertEqual(annotation,
                      torch_xla._XLAC._get_xla_sharding_spec(model.fc1.weight))
     self.assertEqual(annotation,
@@ -158,6 +160,8 @@ def test_fsdp_v2_multi_slice(self):
     output = model(x)
     # Make sure output are sharded.
     annotation = '{devices=[4,1]0,1,2,3}'
+    if self.n_devices == 8:
+      annotation = '{devices=[8,1]0,1,2,3,4,5,6,7}'
     self.assertEqual(annotation, torch_xla._XLAC._get_xla_sharding_spec(x))
     self.assertEqual(annotation, torch_xla._XLAC._get_xla_sharding_spec(output))

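For reference, below is a minimal, standalone sketch (not part of the commit) of how the expected HLO sharding strings asserted in the test can be derived. It assumes a hybrid multi-slice mesh with a replicated data axis across slices and an fsdp axis within each slice, with device ids 0-3 on slice 0 and 4-7 on slice 1; the helper names are illustrative and not torch_xla APIs.

# Sketch only: reproduces the annotation strings asserted in the test under the
# assumed (data, fsdp) mesh layout described above.
import numpy as np

def expected_weight_annotation(n_data: int, n_fsdp: int) -> str:
  # Device ids laid out slice-major: slice 0 owns ids 0..n_fsdp-1, slice 1 the
  # next n_fsdp ids, and so on.
  mesh = np.arange(n_data * n_fsdp).reshape(n_data, n_fsdp)
  # Weights are sharded along 'fsdp' (dim 0) and replicated along 'data', so
  # the tile assignment has shape [n_fsdp, 1, n_data] with the replicated axis
  # last, which is what 'last_tile_dim_replicate' denotes.
  tiles = mesh.T.reshape(n_fsdp, 1, n_data)
  ids = ','.join(str(d) for d in tiles.flatten())
  return '{devices=[%d,1,%d]%s last_tile_dim_replicate}' % (n_fsdp, n_data, ids)

def expected_output_annotation(n_devices: int) -> str:
  # Activations are sharded along the batch dimension across every device.
  ids = ','.join(str(d) for d in range(n_devices))
  return '{devices=[%d,1]%s}' % (n_devices, ids)

print(expected_weight_annotation(2, 2))  # {devices=[2,1,2]0,2,1,3 last_tile_dim_replicate}
print(expected_weight_annotation(2, 4))  # {devices=[4,1,2]0,4,1,5,2,6,3,7 last_tile_dim_replicate}
print(expected_output_annotation(4))     # {devices=[4,1]0,1,2,3}
print(expected_output_annotation(8))     # {devices=[8,1]0,1,2,3,4,5,6,7}

This matches the new 8-device expectations in the diff: with 2 slices of 4 devices, each weight shard is held by a pair of devices (i, i+4), while the batch-sharded input and output are split across all 8 devices.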
