From 5f8063af26cae408ffa5615fa2109ed47322b76c Mon Sep 17 00:00:00 2001 From: Jiewen Tan Date: Fri, 10 May 2024 01:13:26 +0000 Subject: [PATCH 1/4] initial commit --- test/spmd/test_fsdp_v2.py | 65 ++++++++++++++++--- .../spmd_fully_sharded_data_parallel.py | 10 ++- 2 files changed, 65 insertions(+), 10 deletions(-) diff --git a/test/spmd/test_fsdp_v2.py b/test/spmd/test_fsdp_v2.py index ae997892547..0f8d1a7d954 100644 --- a/test/spmd/test_fsdp_v2.py +++ b/test/spmd/test_fsdp_v2.py @@ -26,9 +26,9 @@ def setUpClass(cls): def test_fsdp_v2_basic(self): model = self.SimpleLinear().to(xm.xla_device()) mesh = self._get_mesh((self.n_devices, 1), None, ('fsdp', 'tensor')) - model.fc1 = FSDPv2(model.fc1, mesh) - model.fc2 = FSDPv2(model.fc2, mesh) - model = FSDPv2(model, mesh) + model.fc1 = FSDPv2(model.fc1, mesh=mesh) + model.fc2 = FSDPv2(model.fc2, mesh=mesh) + model = FSDPv2(model, mesh=mesh) # Make sure all weights are sharded. if self.n_devices > 1: @@ -67,9 +67,9 @@ def test_fsdp_v2_output_correctness(self): model = copy.deepcopy(model_expected) mesh = self._get_mesh((self.n_devices, 1), None, ('fsdp', 'tensor')) - model.fc1 = FSDPv2(model.fc1, mesh) - model.fc2 = FSDPv2(model.fc2, mesh) - model = FSDPv2(model, mesh) + model.fc1 = FSDPv2(model.fc1, mesh=mesh) + model.fc2 = FSDPv2(model.fc2, mesh=mesh) + model = FSDPv2(model, mesh=mesh) x_expected = torch.randn(16, 128).to(xm.xla_device()) @@ -87,7 +87,7 @@ def test_fsdp_v2_auto_wrap_basic(self): transformer_auto_wrap_policy, transformer_layer_cls={torch.nn.Linear}, ) - model = FSDPv2(model, mesh, auto_wrap_policy=auto_wrap_policy) + model = FSDPv2(model, mesh=mesh, auto_wrap_policy=auto_wrap_policy) self.assertTrue(isinstance(model.fc1, FSDPv2)) self.assertTrue(isinstance(model.fc2, FSDPv2)) @@ -106,7 +106,7 @@ def auto_wrapper_callable(m, *args, **kwargs): model = FSDPv2( model, - mesh, + mesh=mesh, auto_wrap_policy=auto_wrap_policy, auto_wrapper_callable=auto_wrapper_callable) @@ -139,6 +139,55 @@ def test_fsdp_v2_cpu_model(self): self.assertEqual( str(list(model._orig_module.parameters())[0].device), "xla:0") + @unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.") + def test_fsdp_v2_multi_slice(self): + model = self.SimpleLinear().to(xm.xla_device()) + mesh = self._get_mesh((2, self.n_devices // 2, 1), None, ('data', 'fsdp', 'tensor')) + model = FSDPv2(model, mesh=mesh, extra_data_axis="data") + + # Make sure all weights are sharded. + annotation = '{devices=[2,1,2]0,2,1,3 last_tile_dim_replicate}' + self.assertEqual(annotation, + torch_xla._XLAC._get_xla_sharding_spec(model.fc1.weight)) + self.assertEqual(annotation, + torch_xla._XLAC._get_xla_sharding_spec(model.fc2.weight)) + + x = torch.randn(16, 128).to(xm.xla_device()) + xs.mark_sharding(x, mesh, (('data', 'fsdp'), None)) + output = model(x) + # Make sure output are sharded. + annotation = '{devices=[4,1]0,2,1,3}' + self.assertEqual(annotation, + torch_xla._XLAC._get_xla_sharding_spec(output)) + + # Make sure the model can execute without error. 
+ xm.mark_step() + xm.wait_device_ops() + + @unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.") + def test_fsdp_v2_multi_slice_output_correctness(self): + model_expected = self.SimpleLinear().to(xm.xla_device()) + + model = copy.deepcopy(model_expected) + mesh = self._get_mesh((2, self.n_devices // 2, 1), None, ('data', 'fsdp', 'tensor')) + model = FSDPv2(model, mesh=mesh, extra_data_axis="data") + + x_expected = torch.randn(16, 128).to(xm.xla_device()) + + x = copy.deepcopy(x_expected) + xs.mark_sharding(x, mesh, (('data', 'fsdp'), None)) + + output_expected = model_expected(x_expected) + output = model(x) + self.assertTrue(torch.allclose(output_expected.cpu(), output.cpu())) + + def test_fsdp_v2_multi_slice_error(self): + model = self.SimpleLinear().to(xm.xla_device()) + xs.set_global_mesh(self._get_mesh((2, self.n_devices // 2, 1), None, ('data', 'fsdp', 'tensor'))) + + with self.assertRaisesRegex(ValueError, "The provided ddp axis is not in the mesh."): + model = FSDPv2(model, extra_data_axis='ddp') + if __name__ == '__main__': test = unittest.main() diff --git a/torch_xla/experimental/spmd_fully_sharded_data_parallel.py b/torch_xla/experimental/spmd_fully_sharded_data_parallel.py index 461d66b8565..6cc7e41ccd4 100644 --- a/torch_xla/experimental/spmd_fully_sharded_data_parallel.py +++ b/torch_xla/experimental/spmd_fully_sharded_data_parallel.py @@ -13,7 +13,7 @@ from torch_xla.distributed.fsdp.wrap import recursive_wrap -def _prepare_spmd_partition_spec(param): +def _prepare_spmd_partition_spec(param, extra_data_axis=None): partition_spec = [None] * len(param.shape) # Skip scalar tensors and it replicated. if len(partition_spec) == 0: @@ -24,6 +24,8 @@ def _prepare_spmd_partition_spec(param): # TODO: should we shard on the maximal dim for param? Then we need # another helper for the output. 
partition_spec[0] = "fsdp" + if extra_data_axis: + partition_spec[0] = ("fsdp", extra_data_axis) return tuple(partition_spec) @@ -44,10 +46,12 @@ class SpmdFullyShardedDataParallel(nn.Module): def __init__( self, module: nn.Module, + *, mesh: Optional[spmd.Mesh] = None, shard_output: Optional[Callable] = None, auto_wrap_policy: Optional[Callable] = None, auto_wrapper_callable: Optional[Callable] = None, + extra_data_axis: Optional[str] = None, ): if isinstance(module, SpmdFullyShardedDataParallel): raise RuntimeError( @@ -74,6 +78,8 @@ def __init__( ) if "fsdp" not in mesh.axis_names: raise ValueError("The mesh must have an axis named 'fsdp'.") + if extra_data_axis and extra_data_axis not in mesh.axis_names: + raise ValueError(f"The provided {extra_data_axis} axis is not in the mesh.") super().__init__() @@ -131,7 +137,7 @@ def shard_output_impl(output, mesh): ) spmd.mark_sharding(real_output, mesh, - _prepare_spmd_partition_spec(real_output)) + _prepare_spmd_partition_spec(real_output, extra_data_axis)) shard_output = shard_output_impl From aeb59b5d8bf6dd65788b2cf20cde4a9ea0096c33 Mon Sep 17 00:00:00 2001 From: Jiewen Tan Date: Fri, 10 May 2024 01:14:36 +0000 Subject: [PATCH 2/4] Fix linters --- test/spmd/test_fsdp_v2.py | 20 +++++++++++-------- .../spmd_fully_sharded_data_parallel.py | 8 +++++--- 2 files changed, 17 insertions(+), 11 deletions(-) diff --git a/test/spmd/test_fsdp_v2.py b/test/spmd/test_fsdp_v2.py index 0f8d1a7d954..81ccce13c42 100644 --- a/test/spmd/test_fsdp_v2.py +++ b/test/spmd/test_fsdp_v2.py @@ -142,23 +142,23 @@ def test_fsdp_v2_cpu_model(self): @unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.") def test_fsdp_v2_multi_slice(self): model = self.SimpleLinear().to(xm.xla_device()) - mesh = self._get_mesh((2, self.n_devices // 2, 1), None, ('data', 'fsdp', 'tensor')) + mesh = self._get_mesh((2, self.n_devices // 2, 1), None, + ('data', 'fsdp', 'tensor')) model = FSDPv2(model, mesh=mesh, extra_data_axis="data") # Make sure all weights are sharded. annotation = '{devices=[2,1,2]0,2,1,3 last_tile_dim_replicate}' self.assertEqual(annotation, - torch_xla._XLAC._get_xla_sharding_spec(model.fc1.weight)) + torch_xla._XLAC._get_xla_sharding_spec(model.fc1.weight)) self.assertEqual(annotation, - torch_xla._XLAC._get_xla_sharding_spec(model.fc2.weight)) + torch_xla._XLAC._get_xla_sharding_spec(model.fc2.weight)) x = torch.randn(16, 128).to(xm.xla_device()) xs.mark_sharding(x, mesh, (('data', 'fsdp'), None)) output = model(x) # Make sure output are sharded. annotation = '{devices=[4,1]0,2,1,3}' - self.assertEqual(annotation, - torch_xla._XLAC._get_xla_sharding_spec(output)) + self.assertEqual(annotation, torch_xla._XLAC._get_xla_sharding_spec(output)) # Make sure the model can execute without error. 
xm.mark_step() @@ -169,7 +169,8 @@ def test_fsdp_v2_multi_slice_output_correctness(self): model_expected = self.SimpleLinear().to(xm.xla_device()) model = copy.deepcopy(model_expected) - mesh = self._get_mesh((2, self.n_devices // 2, 1), None, ('data', 'fsdp', 'tensor')) + mesh = self._get_mesh((2, self.n_devices // 2, 1), None, + ('data', 'fsdp', 'tensor')) model = FSDPv2(model, mesh=mesh, extra_data_axis="data") x_expected = torch.randn(16, 128).to(xm.xla_device()) @@ -183,9 +184,12 @@ def test_fsdp_v2_multi_slice_output_correctness(self): def test_fsdp_v2_multi_slice_error(self): model = self.SimpleLinear().to(xm.xla_device()) - xs.set_global_mesh(self._get_mesh((2, self.n_devices // 2, 1), None, ('data', 'fsdp', 'tensor'))) + xs.set_global_mesh( + self._get_mesh((2, self.n_devices // 2, 1), None, + ('data', 'fsdp', 'tensor'))) - with self.assertRaisesRegex(ValueError, "The provided ddp axis is not in the mesh."): + with self.assertRaisesRegex(ValueError, + "The provided ddp axis is not in the mesh."): model = FSDPv2(model, extra_data_axis='ddp') diff --git a/torch_xla/experimental/spmd_fully_sharded_data_parallel.py b/torch_xla/experimental/spmd_fully_sharded_data_parallel.py index 6cc7e41ccd4..bb01204508d 100644 --- a/torch_xla/experimental/spmd_fully_sharded_data_parallel.py +++ b/torch_xla/experimental/spmd_fully_sharded_data_parallel.py @@ -79,7 +79,8 @@ def __init__( if "fsdp" not in mesh.axis_names: raise ValueError("The mesh must have an axis named 'fsdp'.") if extra_data_axis and extra_data_axis not in mesh.axis_names: - raise ValueError(f"The provided {extra_data_axis} axis is not in the mesh.") + raise ValueError( + f"The provided {extra_data_axis} axis is not in the mesh.") super().__init__() @@ -136,8 +137,9 @@ def shard_output_impl(output, mesh): f"The output type is not supported: {type(output)}. Please provide your own shard_output callable." 
) - spmd.mark_sharding(real_output, mesh, - _prepare_spmd_partition_spec(real_output, extra_data_axis)) + spmd.mark_sharding( + real_output, mesh, + _prepare_spmd_partition_spec(real_output, extra_data_axis)) shard_output = shard_output_impl From 76c7435b2cec973eb127a4e57e6abdf087f9b34d Mon Sep 17 00:00:00 2001 From: Jiewen Tan Date: Fri, 10 May 2024 01:19:47 +0000 Subject: [PATCH 3/4] Add fsdpv2 test to tpu ci --- test/tpu/run_tests.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/test/tpu/run_tests.sh b/test/tpu/run_tests.sh index 2653d0ed8c2..1658bca8977 100755 --- a/test/tpu/run_tests.sh +++ b/test/tpu/run_tests.sh @@ -11,6 +11,7 @@ python3 test/spmd/test_xla_distributed_checkpoint.py python3 test/spmd/test_train_spmd_linear_model.py python3 test/spmd/test_xla_spmd_python_api_interaction.py python3 test/spmd/test_xla_auto_sharding.py +python3 test/spmd/test_fsdp_v2.py XLA_EXPERIMENTAL=nonzero:masked_select:nms python3 test/ds/test_dynamic_shape_models.py -v XLA_EXPERIMENTAL=nonzero:masked_select:nms python3 test/ds/test_dynamic_shapes.py -v python3 test/test_autocast.py From e1a5d7606b32c1ee12525ffbc569dc734f391da4 Mon Sep 17 00:00:00 2001 From: Jiewen Tan Date: Fri, 10 May 2024 19:39:54 +0000 Subject: [PATCH 4/4] Fix comment --- test/spmd/test_fsdp_v2.py | 3 ++- torch_xla/experimental/spmd_fully_sharded_data_parallel.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/test/spmd/test_fsdp_v2.py b/test/spmd/test_fsdp_v2.py index 81ccce13c42..08b588c9edc 100644 --- a/test/spmd/test_fsdp_v2.py +++ b/test/spmd/test_fsdp_v2.py @@ -157,7 +157,8 @@ def test_fsdp_v2_multi_slice(self): xs.mark_sharding(x, mesh, (('data', 'fsdp'), None)) output = model(x) # Make sure output are sharded. - annotation = '{devices=[4,1]0,2,1,3}' + annotation = '{devices=[4,1]0,1,2,3}' + self.assertEqual(annotation, torch_xla._XLAC._get_xla_sharding_spec(x)) self.assertEqual(annotation, torch_xla._XLAC._get_xla_sharding_spec(output)) # Make sure the model can execute without error. diff --git a/torch_xla/experimental/spmd_fully_sharded_data_parallel.py b/torch_xla/experimental/spmd_fully_sharded_data_parallel.py index bb01204508d..994f7e77dbe 100644 --- a/torch_xla/experimental/spmd_fully_sharded_data_parallel.py +++ b/torch_xla/experimental/spmd_fully_sharded_data_parallel.py @@ -25,7 +25,7 @@ def _prepare_spmd_partition_spec(param, extra_data_axis=None): # another helper for the output. partition_spec[0] = "fsdp" if extra_data_axis: - partition_spec[0] = ("fsdp", extra_data_axis) + partition_spec[0] = (extra_data_axis, "fsdp") return tuple(partition_spec)
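
Usage sketch (editor's note, not part of the patch series): the new keyword-only `extra_data_axis` argument lets a multi-slice run keep parameters sharded along the 'fsdp' mesh axis while activations and outputs are sharded along the combined (extra_data_axis, 'fsdp') axes, mirroring the tests added in patch 1/4. The snippet below is a minimal sketch under assumptions: an SPMD-enabled TPU environment with an even device count, and an illustrative mesh shape, model size, and batch; it is not code from the patches themselves.

import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
from torch_xla.experimental.spmd_fully_sharded_data_parallel import (
    SpmdFullyShardedDataParallel as FSDPv2)

xr.use_spmd()
num_devices = xr.global_runtime_device_count()

# 3D mesh matching the tests: an extra 'data' axis (e.g. one entry per slice),
# the usual 'fsdp' axis, and a trivial 'tensor' axis. The (2, n // 2, 1) shape
# is illustrative and assumes an even number of devices.
mesh = xs.Mesh(
    np.arange(num_devices), (2, num_devices // 2, 1),
    ('data', 'fsdp', 'tensor'))

model = torch.nn.Linear(128, 64).to(xm.xla_device())
# Parameters are sharded along 'fsdp' and replicated along 'data';
# extra_data_axis only widens the output sharding to ('data', 'fsdp').
model = FSDPv2(model, mesh=mesh, extra_data_axis='data')

x = torch.randn(16, 128).to(xm.xla_device())
# Shard the batch dimension of the input across both axes, as the tests do.
xs.mark_sharding(x, mesh, (('data', 'fsdp'), None))

output = model(x)
xm.mark_step()

Note that patch 4/4 reorders the combined partition spec from ("fsdp", extra_data_axis) to (extra_data_axis, "fsdp"), which is why the expected output annotation in the test changes from {devices=[4,1]0,2,1,3} to {devices=[4,1]0,1,2,3}, matching the sharding of the input.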