From 6f0b61e5d782913a0fc7743812f2a8e522189111 Mon Sep 17 00:00:00 2001
From: Jiewen Tan
Date: Fri, 10 May 2024 17:35:06 -0700
Subject: [PATCH] [FSDPv2] Support MultiSlice (#7044)

Summary:
This pull request adds multi-slice support to FSDPv2. By default, the dcn axis
is used as the data axis, which means we only do data parallelism across
slices. In the future, we could also support FSDP across multiple slices.

Test Plan:
PJRT_DEVICE=TPU python test/spmd/test_fsdp_v2.py
---
 test/spmd/test_fsdp_v2.py                         | 70 ++++++++++++++++---
 test/tpu/run_tests.sh                             |  1 +
 .../spmd_fully_sharded_data_parallel.py           | 14 +++-
 3 files changed, 74 insertions(+), 11 deletions(-)

diff --git a/test/spmd/test_fsdp_v2.py b/test/spmd/test_fsdp_v2.py
index ae997892547..08b588c9edc 100644
--- a/test/spmd/test_fsdp_v2.py
+++ b/test/spmd/test_fsdp_v2.py
@@ -26,9 +26,9 @@ def setUpClass(cls):
   def test_fsdp_v2_basic(self):
     model = self.SimpleLinear().to(xm.xla_device())
     mesh = self._get_mesh((self.n_devices, 1), None, ('fsdp', 'tensor'))
-    model.fc1 = FSDPv2(model.fc1, mesh)
-    model.fc2 = FSDPv2(model.fc2, mesh)
-    model = FSDPv2(model, mesh)
+    model.fc1 = FSDPv2(model.fc1, mesh=mesh)
+    model.fc2 = FSDPv2(model.fc2, mesh=mesh)
+    model = FSDPv2(model, mesh=mesh)
 
     # Make sure all weights are sharded.
     if self.n_devices > 1:
@@ -67,9 +67,9 @@ def test_fsdp_v2_output_correctness(self):
 
     model = copy.deepcopy(model_expected)
     mesh = self._get_mesh((self.n_devices, 1), None, ('fsdp', 'tensor'))
-    model.fc1 = FSDPv2(model.fc1, mesh)
-    model.fc2 = FSDPv2(model.fc2, mesh)
-    model = FSDPv2(model, mesh)
+    model.fc1 = FSDPv2(model.fc1, mesh=mesh)
+    model.fc2 = FSDPv2(model.fc2, mesh=mesh)
+    model = FSDPv2(model, mesh=mesh)
 
     x_expected = torch.randn(16, 128).to(xm.xla_device())
 
@@ -87,7 +87,7 @@ def test_fsdp_v2_auto_wrap_basic(self):
         transformer_auto_wrap_policy,
         transformer_layer_cls={torch.nn.Linear},
     )
-    model = FSDPv2(model, mesh, auto_wrap_policy=auto_wrap_policy)
+    model = FSDPv2(model, mesh=mesh, auto_wrap_policy=auto_wrap_policy)
 
     self.assertTrue(isinstance(model.fc1, FSDPv2))
     self.assertTrue(isinstance(model.fc2, FSDPv2))
@@ -106,7 +106,7 @@ def auto_wrapper_callable(m, *args, **kwargs):
 
     model = FSDPv2(
         model,
-        mesh,
+        mesh=mesh,
         auto_wrap_policy=auto_wrap_policy,
         auto_wrapper_callable=auto_wrapper_callable)
 
@@ -139,6 +139,60 @@ def test_fsdp_v2_cpu_model(self):
     self.assertEqual(
         str(list(model._orig_module.parameters())[0].device), "xla:0")
 
+  @unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
+  def test_fsdp_v2_multi_slice(self):
+    model = self.SimpleLinear().to(xm.xla_device())
+    mesh = self._get_mesh((2, self.n_devices // 2, 1), None,
+                          ('data', 'fsdp', 'tensor'))
+    model = FSDPv2(model, mesh=mesh, extra_data_axis="data")
+
+    # Make sure all weights are sharded.
+    annotation = '{devices=[2,1,2]0,2,1,3 last_tile_dim_replicate}'
+    self.assertEqual(annotation,
+                     torch_xla._XLAC._get_xla_sharding_spec(model.fc1.weight))
+    self.assertEqual(annotation,
+                     torch_xla._XLAC._get_xla_sharding_spec(model.fc2.weight))
+
+    x = torch.randn(16, 128).to(xm.xla_device())
+    xs.mark_sharding(x, mesh, (('data', 'fsdp'), None))
+    output = model(x)
+    # Make sure the outputs are sharded.
+    annotation = '{devices=[4,1]0,1,2,3}'
+    self.assertEqual(annotation, torch_xla._XLAC._get_xla_sharding_spec(x))
+    self.assertEqual(annotation, torch_xla._XLAC._get_xla_sharding_spec(output))
+
+    # Make sure the model can execute without error.
+    xm.mark_step()
+    xm.wait_device_ops()
+
+  @unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
+  def test_fsdp_v2_multi_slice_output_correctness(self):
+    model_expected = self.SimpleLinear().to(xm.xla_device())
+
+    model = copy.deepcopy(model_expected)
+    mesh = self._get_mesh((2, self.n_devices // 2, 1), None,
+                          ('data', 'fsdp', 'tensor'))
+    model = FSDPv2(model, mesh=mesh, extra_data_axis="data")
+
+    x_expected = torch.randn(16, 128).to(xm.xla_device())
+
+    x = copy.deepcopy(x_expected)
+    xs.mark_sharding(x, mesh, (('data', 'fsdp'), None))
+
+    output_expected = model_expected(x_expected)
+    output = model(x)
+    self.assertTrue(torch.allclose(output_expected.cpu(), output.cpu()))
+
+  def test_fsdp_v2_multi_slice_error(self):
+    model = self.SimpleLinear().to(xm.xla_device())
+    xs.set_global_mesh(
+        self._get_mesh((2, self.n_devices // 2, 1), None,
+                       ('data', 'fsdp', 'tensor')))
+
+    with self.assertRaisesRegex(ValueError,
+                                "The provided ddp axis is not in the mesh."):
+      model = FSDPv2(model, extra_data_axis='ddp')
+
 
 if __name__ == '__main__':
   test = unittest.main()
diff --git a/test/tpu/run_tests.sh b/test/tpu/run_tests.sh
index 2653d0ed8c2..1658bca8977 100755
--- a/test/tpu/run_tests.sh
+++ b/test/tpu/run_tests.sh
@@ -11,6 +11,7 @@ python3 test/spmd/test_xla_distributed_checkpoint.py
 python3 test/spmd/test_train_spmd_linear_model.py
 python3 test/spmd/test_xla_spmd_python_api_interaction.py
 python3 test/spmd/test_xla_auto_sharding.py
+python3 test/spmd/test_fsdp_v2.py
 XLA_EXPERIMENTAL=nonzero:masked_select:nms python3 test/ds/test_dynamic_shape_models.py -v
 XLA_EXPERIMENTAL=nonzero:masked_select:nms python3 test/ds/test_dynamic_shapes.py -v
 python3 test/test_autocast.py
diff --git a/torch_xla/experimental/spmd_fully_sharded_data_parallel.py b/torch_xla/experimental/spmd_fully_sharded_data_parallel.py
index 461d66b8565..994f7e77dbe 100644
--- a/torch_xla/experimental/spmd_fully_sharded_data_parallel.py
+++ b/torch_xla/experimental/spmd_fully_sharded_data_parallel.py
@@ -13,7 +13,7 @@
 from torch_xla.distributed.fsdp.wrap import recursive_wrap
 
 
-def _prepare_spmd_partition_spec(param):
+def _prepare_spmd_partition_spec(param, extra_data_axis=None):
   partition_spec = [None] * len(param.shape)
   # Skip scalar tensors and it replicated.
   if len(partition_spec) == 0:
@@ -24,6 +24,8 @@ def _prepare_spmd_partition_spec(param):
   # TODO: should we shard on the maximal dim for param? Then we need
   # another helper for the output.
   partition_spec[0] = "fsdp"
+  if extra_data_axis:
+    partition_spec[0] = (extra_data_axis, "fsdp")
   return tuple(partition_spec)
 
 
@@ -44,10 +46,12 @@ class SpmdFullyShardedDataParallel(nn.Module):
   def __init__(
       self,
       module: nn.Module,
+      *,
       mesh: Optional[spmd.Mesh] = None,
       shard_output: Optional[Callable] = None,
       auto_wrap_policy: Optional[Callable] = None,
       auto_wrapper_callable: Optional[Callable] = None,
+      extra_data_axis: Optional[str] = None,
   ):
     if isinstance(module, SpmdFullyShardedDataParallel):
       raise RuntimeError(
@@ -74,6 +78,9 @@ def __init__(
         )
     if "fsdp" not in mesh.axis_names:
       raise ValueError("The mesh must have an axis named 'fsdp'.")
+    if extra_data_axis and extra_data_axis not in mesh.axis_names:
+      raise ValueError(
+          f"The provided {extra_data_axis} axis is not in the mesh.")
 
     super().__init__()
 
@@ -130,8 +137,9 @@ def shard_output_impl(output, mesh):
             f"The output type is not supported: {type(output)}. Please provide your own shard_output callable."
         )
 
-      spmd.mark_sharding(real_output, mesh,
-                         _prepare_spmd_partition_spec(real_output))
+      spmd.mark_sharding(
+          real_output, mesh,
+          _prepare_spmd_partition_spec(real_output, extra_data_axis))
 
     shard_output = shard_output_impl
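
Usage sketch (not part of the patch): assuming a two-slice TPU environment, a placeholder MyModel module, and the slice count chosen below, the new extra_data_axis argument would be exercised roughly as follows, mirroring the tests added above.

# Illustrative only; MyModel and num_slices are assumptions, not part of the patch.
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr
from torch_xla.experimental.spmd_fully_sharded_data_parallel import (
    SpmdFullyShardedDataParallel as FSDPv2)

xr.use_spmd()

num_devices = xr.global_runtime_device_count()
num_slices = 2  # assumed number of slices for this sketch
# 'data' spans slices (pure data parallelism), 'fsdp' shards parameters within a slice.
mesh = xs.Mesh(
    list(range(num_devices)), (num_slices, num_devices // num_slices, 1),
    ('data', 'fsdp', 'tensor'))

model = MyModel().to(xm.xla_device())  # MyModel stands in for a real nn.Module
model = FSDPv2(model, mesh=mesh, extra_data_axis='data')

x = torch.randn(16, 128).to(xm.xla_device())
# Shard the batch dimension over both the data and fsdp axes, as in the tests.
xs.mark_sharding(x, mesh, (('data', 'fsdp'), None))
output = model(x)
xm.mark_step()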