diff --git a/test/spmd/test_fsdp_v2.py b/test/spmd/test_fsdp_v2.py index 0f8d1a7d954..81ccce13c42 100644 --- a/test/spmd/test_fsdp_v2.py +++ b/test/spmd/test_fsdp_v2.py @@ -142,23 +142,23 @@ def test_fsdp_v2_cpu_model(self): @unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.") def test_fsdp_v2_multi_slice(self): model = self.SimpleLinear().to(xm.xla_device()) - mesh = self._get_mesh((2, self.n_devices // 2, 1), None, ('data', 'fsdp', 'tensor')) + mesh = self._get_mesh((2, self.n_devices // 2, 1), None, + ('data', 'fsdp', 'tensor')) model = FSDPv2(model, mesh=mesh, extra_data_axis="data") # Make sure all weights are sharded. annotation = '{devices=[2,1,2]0,2,1,3 last_tile_dim_replicate}' self.assertEqual(annotation, - torch_xla._XLAC._get_xla_sharding_spec(model.fc1.weight)) + torch_xla._XLAC._get_xla_sharding_spec(model.fc1.weight)) self.assertEqual(annotation, - torch_xla._XLAC._get_xla_sharding_spec(model.fc2.weight)) + torch_xla._XLAC._get_xla_sharding_spec(model.fc2.weight)) x = torch.randn(16, 128).to(xm.xla_device()) xs.mark_sharding(x, mesh, (('data', 'fsdp'), None)) output = model(x) # Make sure output are sharded. annotation = '{devices=[4,1]0,2,1,3}' - self.assertEqual(annotation, - torch_xla._XLAC._get_xla_sharding_spec(output)) + self.assertEqual(annotation, torch_xla._XLAC._get_xla_sharding_spec(output)) # Make sure the model can execute without error. xm.mark_step() @@ -169,7 +169,8 @@ def test_fsdp_v2_multi_slice_output_correctness(self): model_expected = self.SimpleLinear().to(xm.xla_device()) model = copy.deepcopy(model_expected) - mesh = self._get_mesh((2, self.n_devices // 2, 1), None, ('data', 'fsdp', 'tensor')) + mesh = self._get_mesh((2, self.n_devices // 2, 1), None, + ('data', 'fsdp', 'tensor')) model = FSDPv2(model, mesh=mesh, extra_data_axis="data") x_expected = torch.randn(16, 128).to(xm.xla_device()) @@ -183,9 +184,12 @@ def test_fsdp_v2_multi_slice_output_correctness(self): def test_fsdp_v2_multi_slice_error(self): model = self.SimpleLinear().to(xm.xla_device()) - xs.set_global_mesh(self._get_mesh((2, self.n_devices // 2, 1), None, ('data', 'fsdp', 'tensor'))) + xs.set_global_mesh( + self._get_mesh((2, self.n_devices // 2, 1), None, + ('data', 'fsdp', 'tensor'))) - with self.assertRaisesRegex(ValueError, "The provided ddp axis is not in the mesh."): + with self.assertRaisesRegex(ValueError, + "The provided ddp axis is not in the mesh."): model = FSDPv2(model, extra_data_axis='ddp') diff --git a/torch_xla/experimental/spmd_fully_sharded_data_parallel.py b/torch_xla/experimental/spmd_fully_sharded_data_parallel.py index 6cc7e41ccd4..bb01204508d 100644 --- a/torch_xla/experimental/spmd_fully_sharded_data_parallel.py +++ b/torch_xla/experimental/spmd_fully_sharded_data_parallel.py @@ -79,7 +79,8 @@ def __init__( if "fsdp" not in mesh.axis_names: raise ValueError("The mesh must have an axis named 'fsdp'.") if extra_data_axis and extra_data_axis not in mesh.axis_names: - raise ValueError(f"The provided {extra_data_axis} axis is not in the mesh.") + raise ValueError( + f"The provided {extra_data_axis} axis is not in the mesh.") super().__init__() @@ -136,8 +137,9 @@ def shard_output_impl(output, mesh): f"The output type is not supported: {type(output)}. Please provide your own shard_output callable." ) - spmd.mark_sharding(real_output, mesh, - _prepare_spmd_partition_spec(real_output, extra_data_axis)) + spmd.mark_sharding( + real_output, mesh, + _prepare_spmd_partition_spec(real_output, extra_data_axis)) shard_output = shard_output_impl