diff --git a/test/spmd/test_fsdp_v2.py b/test/spmd/test_fsdp_v2.py index ae997892547e..0f8d1a7d954c 100644 --- a/test/spmd/test_fsdp_v2.py +++ b/test/spmd/test_fsdp_v2.py @@ -26,9 +26,9 @@ def setUpClass(cls): def test_fsdp_v2_basic(self): model = self.SimpleLinear().to(xm.xla_device()) mesh = self._get_mesh((self.n_devices, 1), None, ('fsdp', 'tensor')) - model.fc1 = FSDPv2(model.fc1, mesh) - model.fc2 = FSDPv2(model.fc2, mesh) - model = FSDPv2(model, mesh) + model.fc1 = FSDPv2(model.fc1, mesh=mesh) + model.fc2 = FSDPv2(model.fc2, mesh=mesh) + model = FSDPv2(model, mesh=mesh) # Make sure all weights are sharded. if self.n_devices > 1: @@ -67,9 +67,9 @@ def test_fsdp_v2_output_correctness(self): model = copy.deepcopy(model_expected) mesh = self._get_mesh((self.n_devices, 1), None, ('fsdp', 'tensor')) - model.fc1 = FSDPv2(model.fc1, mesh) - model.fc2 = FSDPv2(model.fc2, mesh) - model = FSDPv2(model, mesh) + model.fc1 = FSDPv2(model.fc1, mesh=mesh) + model.fc2 = FSDPv2(model.fc2, mesh=mesh) + model = FSDPv2(model, mesh=mesh) x_expected = torch.randn(16, 128).to(xm.xla_device()) @@ -87,7 +87,7 @@ def test_fsdp_v2_auto_wrap_basic(self): transformer_auto_wrap_policy, transformer_layer_cls={torch.nn.Linear}, ) - model = FSDPv2(model, mesh, auto_wrap_policy=auto_wrap_policy) + model = FSDPv2(model, mesh=mesh, auto_wrap_policy=auto_wrap_policy) self.assertTrue(isinstance(model.fc1, FSDPv2)) self.assertTrue(isinstance(model.fc2, FSDPv2)) @@ -106,7 +106,7 @@ def auto_wrapper_callable(m, *args, **kwargs): model = FSDPv2( model, - mesh, + mesh=mesh, auto_wrap_policy=auto_wrap_policy, auto_wrapper_callable=auto_wrapper_callable) @@ -139,6 +139,55 @@ def test_fsdp_v2_cpu_model(self): self.assertEqual( str(list(model._orig_module.parameters())[0].device), "xla:0") + @unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.") + def test_fsdp_v2_multi_slice(self): + model = self.SimpleLinear().to(xm.xla_device()) + mesh = self._get_mesh((2, self.n_devices // 2, 1), None, ('data', 'fsdp', 'tensor')) + model = FSDPv2(model, mesh=mesh, extra_data_axis="data") + + # Make sure all weights are sharded. + annotation = '{devices=[2,1,2]0,2,1,3 last_tile_dim_replicate}' + self.assertEqual(annotation, + torch_xla._XLAC._get_xla_sharding_spec(model.fc1.weight)) + self.assertEqual(annotation, + torch_xla._XLAC._get_xla_sharding_spec(model.fc2.weight)) + + x = torch.randn(16, 128).to(xm.xla_device()) + xs.mark_sharding(x, mesh, (('data', 'fsdp'), None)) + output = model(x) + # Make sure output are sharded. + annotation = '{devices=[4,1]0,2,1,3}' + self.assertEqual(annotation, + torch_xla._XLAC._get_xla_sharding_spec(output)) + + # Make sure the model can execute without error. + xm.mark_step() + xm.wait_device_ops() + + @unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.") + def test_fsdp_v2_multi_slice_output_correctness(self): + model_expected = self.SimpleLinear().to(xm.xla_device()) + + model = copy.deepcopy(model_expected) + mesh = self._get_mesh((2, self.n_devices // 2, 1), None, ('data', 'fsdp', 'tensor')) + model = FSDPv2(model, mesh=mesh, extra_data_axis="data") + + x_expected = torch.randn(16, 128).to(xm.xla_device()) + + x = copy.deepcopy(x_expected) + xs.mark_sharding(x, mesh, (('data', 'fsdp'), None)) + + output_expected = model_expected(x_expected) + output = model(x) + self.assertTrue(torch.allclose(output_expected.cpu(), output.cpu())) + + def test_fsdp_v2_multi_slice_error(self): + model = self.SimpleLinear().to(xm.xla_device()) + xs.set_global_mesh(self._get_mesh((2, self.n_devices // 2, 1), None, ('data', 'fsdp', 'tensor'))) + + with self.assertRaisesRegex(ValueError, "The provided ddp axis is not in the mesh."): + model = FSDPv2(model, extra_data_axis='ddp') + if __name__ == '__main__': test = unittest.main() diff --git a/torch_xla/experimental/spmd_fully_sharded_data_parallel.py b/torch_xla/experimental/spmd_fully_sharded_data_parallel.py index 461d66b8565c..6cc7e41ccd40 100644 --- a/torch_xla/experimental/spmd_fully_sharded_data_parallel.py +++ b/torch_xla/experimental/spmd_fully_sharded_data_parallel.py @@ -13,7 +13,7 @@ from torch_xla.distributed.fsdp.wrap import recursive_wrap -def _prepare_spmd_partition_spec(param): +def _prepare_spmd_partition_spec(param, extra_data_axis=None): partition_spec = [None] * len(param.shape) # Skip scalar tensors and it replicated. if len(partition_spec) == 0: @@ -24,6 +24,8 @@ def _prepare_spmd_partition_spec(param): # TODO: should we shard on the maximal dim for param? Then we need # another helper for the output. partition_spec[0] = "fsdp" + if extra_data_axis: + partition_spec[0] = ("fsdp", extra_data_axis) return tuple(partition_spec) @@ -44,10 +46,12 @@ class SpmdFullyShardedDataParallel(nn.Module): def __init__( self, module: nn.Module, + *, mesh: Optional[spmd.Mesh] = None, shard_output: Optional[Callable] = None, auto_wrap_policy: Optional[Callable] = None, auto_wrapper_callable: Optional[Callable] = None, + extra_data_axis: Optional[str] = None, ): if isinstance(module, SpmdFullyShardedDataParallel): raise RuntimeError( @@ -74,6 +78,8 @@ def __init__( ) if "fsdp" not in mesh.axis_names: raise ValueError("The mesh must have an axis named 'fsdp'.") + if extra_data_axis and extra_data_axis not in mesh.axis_names: + raise ValueError(f"The provided {extra_data_axis} axis is not in the mesh.") super().__init__() @@ -131,7 +137,7 @@ def shard_output_impl(output, mesh): ) spmd.mark_sharding(real_output, mesh, - _prepare_spmd_partition_spec(real_output)) + _prepare_spmd_partition_spec(real_output, extra_data_axis)) shard_output = shard_output_impl