
Fix linters
alanwaketan committed May 10, 2024
1 parent 6e26935 commit 39105fa
Showing 2 changed files with 17 additions and 11 deletions.
20 changes: 12 additions & 8 deletions test/spmd/test_fsdp_v2.py
@@ -142,23 +142,23 @@ def test_fsdp_v2_cpu_model(self):
   @unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
   def test_fsdp_v2_multi_slice(self):
     model = self.SimpleLinear().to(xm.xla_device())
-    mesh = self._get_mesh((2, self.n_devices // 2, 1), None, ('data', 'fsdp', 'tensor'))
+    mesh = self._get_mesh((2, self.n_devices // 2, 1), None,
+                          ('data', 'fsdp', 'tensor'))
     model = FSDPv2(model, mesh=mesh, extra_data_axis="data")

     # Make sure all weights are sharded.
     annotation = '{devices=[2,1,2]0,2,1,3 last_tile_dim_replicate}'
     self.assertEqual(annotation,
-        torch_xla._XLAC._get_xla_sharding_spec(model.fc1.weight))
+                     torch_xla._XLAC._get_xla_sharding_spec(model.fc1.weight))
     self.assertEqual(annotation,
-        torch_xla._XLAC._get_xla_sharding_spec(model.fc2.weight))
+                     torch_xla._XLAC._get_xla_sharding_spec(model.fc2.weight))

     x = torch.randn(16, 128).to(xm.xla_device())
     xs.mark_sharding(x, mesh, (('data', 'fsdp'), None))
     output = model(x)
     # Make sure output are sharded.
     annotation = '{devices=[4,1]0,2,1,3}'
-    self.assertEqual(annotation,
-                     torch_xla._XLAC._get_xla_sharding_spec(output))
+    self.assertEqual(annotation, torch_xla._XLAC._get_xla_sharding_spec(output))

     # Make sure the model can execute without error.
     xm.mark_step()
@@ -169,7 +169,8 @@ def test_fsdp_v2_multi_slice_output_correctness(self):
     model_expected = self.SimpleLinear().to(xm.xla_device())

     model = copy.deepcopy(model_expected)
-    mesh = self._get_mesh((2, self.n_devices // 2, 1), None, ('data', 'fsdp', 'tensor'))
+    mesh = self._get_mesh((2, self.n_devices // 2, 1), None,
+                          ('data', 'fsdp', 'tensor'))
     model = FSDPv2(model, mesh=mesh, extra_data_axis="data")

     x_expected = torch.randn(16, 128).to(xm.xla_device())
@@ -183,9 +184,12 @@ def test_fsdp_v2_multi_slice_output_correctness(self):

   def test_fsdp_v2_multi_slice_error(self):
     model = self.SimpleLinear().to(xm.xla_device())
-    xs.set_global_mesh(self._get_mesh((2, self.n_devices // 2, 1), None, ('data', 'fsdp', 'tensor')))
+    xs.set_global_mesh(
+        self._get_mesh((2, self.n_devices // 2, 1), None,
+                       ('data', 'fsdp', 'tensor')))

-    with self.assertRaisesRegex(ValueError, "The provided ddp axis is not in the mesh."):
+    with self.assertRaisesRegex(ValueError,
+                                "The provided ddp axis is not in the mesh."):
       model = FSDPv2(model, extra_data_axis='ddp')


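For context, the pattern these tests exercise looks roughly like the following outside the test harness. This is an illustrative sketch only: the xs.Mesh construction, the nn.Linear stand-in model, and the tensor shapes are assumptions, while FSDPv2, xs.mark_sharding, and xm.mark_step are the same calls shown in the diff above.

# Illustrative sketch: multi-slice FSDPv2 usage on a two-slice TPU topology.
# The mesh shape and the stand-in model are assumptions for this sketch.
import numpy as np
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr
from torch_xla.experimental.spmd_fully_sharded_data_parallel import (
    SpmdFullyShardedDataParallel as FSDPv2)

xr.use_spmd()
num_devices = xr.global_runtime_device_count()
# 'data' spans the slices, 'fsdp' shards parameters, 'tensor' is unused here.
mesh = xs.Mesh(
    np.arange(num_devices), (2, num_devices // 2, 1),
    ('data', 'fsdp', 'tensor'))

model = FSDPv2(
    nn.Linear(128, 64).to(xm.xla_device()),
    mesh=mesh,
    extra_data_axis="data")

x = torch.randn(16, 128).to(xm.xla_device())
# Shard the batch dimension over both 'data' and 'fsdp', as in the test.
xs.mark_sharding(x, mesh, (('data', 'fsdp'), None))
output = model(x)
xm.mark_step()

With extra_data_axis="data", parameters are sharded along 'fsdp' and replicated across the extra 'data' axis, which matches the last_tile_dim_replicate annotations the test asserts.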
8 changes: 5 additions & 3 deletions torch_xla/experimental/spmd_fully_sharded_data_parallel.py
@@ -79,7 +79,8 @@ def __init__(
if "fsdp" not in mesh.axis_names:
raise ValueError("The mesh must have an axis named 'fsdp'.")
if extra_data_axis and extra_data_axis not in mesh.axis_names:
raise ValueError(f"The provided {extra_data_axis} axis is not in the mesh.")
raise ValueError(
f"The provided {extra_data_axis} axis is not in the mesh.")

super().__init__()

@@ -136,8 +137,9 @@ def shard_output_impl(output, mesh):
f"The output type is not supported: {type(output)}. Please provide your own shard_output callable."
)

spmd.mark_sharding(real_output, mesh,
_prepare_spmd_partition_spec(real_output, extra_data_axis))
spmd.mark_sharding(
real_output, mesh,
_prepare_spmd_partition_spec(real_output, extra_data_axis))

shard_output = shard_output_impl

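When a model's forward() returns something the default shard_output_impl does not recognize, the RuntimeError above asks the caller to supply their own shard_output callable. A hedged sketch of such a callable follows, reusing the mesh and FSDPv2 import from the earlier sketch and assuming the wrapper accepts it as a shard_output constructor argument and that forward() returns a dict with a 'logits' tensor.

# Illustrative only: custom output sharding for a forward() that returns a
# dict. The 'logits' key, the partition spec, and the shard_output keyword
# are assumptions; spmd.mark_sharding is the same call used by the default
# implementation above.
import torch_xla.distributed.spmd as spmd

def shard_dict_output(output, mesh):
  # Shard the batch dimension of the logits along 'fsdp'; the remaining
  # dimensions stay replicated.
  spmd.mark_sharding(output["logits"], mesh, ("fsdp", None))

# my_module is a placeholder for a module whose forward() returns such a dict.
model = FSDPv2(my_module, mesh=mesh, shard_output=shard_dict_output)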
