Commit

initial commit
alanwaketan committed May 10, 2024
1 parent 887d344 commit 6e26935
Showing 2 changed files with 65 additions and 10 deletions.
65 changes: 57 additions & 8 deletions test/spmd/test_fsdp_v2.py
@@ -26,9 +26,9 @@ def setUpClass(cls):
   def test_fsdp_v2_basic(self):
     model = self.SimpleLinear().to(xm.xla_device())
     mesh = self._get_mesh((self.n_devices, 1), None, ('fsdp', 'tensor'))
-    model.fc1 = FSDPv2(model.fc1, mesh)
-    model.fc2 = FSDPv2(model.fc2, mesh)
-    model = FSDPv2(model, mesh)
+    model.fc1 = FSDPv2(model.fc1, mesh=mesh)
+    model.fc2 = FSDPv2(model.fc2, mesh=mesh)
+    model = FSDPv2(model, mesh=mesh)
 
     # Make sure all weights are sharded.
     if self.n_devices > 1:
@@ -67,9 +67,9 @@ def test_fsdp_v2_output_correctness(self):
 
     model = copy.deepcopy(model_expected)
     mesh = self._get_mesh((self.n_devices, 1), None, ('fsdp', 'tensor'))
-    model.fc1 = FSDPv2(model.fc1, mesh)
-    model.fc2 = FSDPv2(model.fc2, mesh)
-    model = FSDPv2(model, mesh)
+    model.fc1 = FSDPv2(model.fc1, mesh=mesh)
+    model.fc2 = FSDPv2(model.fc2, mesh=mesh)
+    model = FSDPv2(model, mesh=mesh)
 
     x_expected = torch.randn(16, 128).to(xm.xla_device())
 
@@ -87,7 +87,7 @@ def test_fsdp_v2_auto_wrap_basic(self):
         transformer_auto_wrap_policy,
         transformer_layer_cls={torch.nn.Linear},
     )
-    model = FSDPv2(model, mesh, auto_wrap_policy=auto_wrap_policy)
+    model = FSDPv2(model, mesh=mesh, auto_wrap_policy=auto_wrap_policy)
 
     self.assertTrue(isinstance(model.fc1, FSDPv2))
     self.assertTrue(isinstance(model.fc2, FSDPv2))
@@ -106,7 +106,7 @@ def auto_wrapper_callable(m, *args, **kwargs):
 
     model = FSDPv2(
         model,
-        mesh,
+        mesh=mesh,
         auto_wrap_policy=auto_wrap_policy,
         auto_wrapper_callable=auto_wrapper_callable)
 

Expand Down Expand Up @@ -139,6 +139,55 @@ def test_fsdp_v2_cpu_model(self):
self.assertEqual(
str(list(model._orig_module.parameters())[0].device), "xla:0")

@unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
def test_fsdp_v2_multi_slice(self):
model = self.SimpleLinear().to(xm.xla_device())
mesh = self._get_mesh((2, self.n_devices // 2, 1), None, ('data', 'fsdp', 'tensor'))
model = FSDPv2(model, mesh=mesh, extra_data_axis="data")

# Make sure all weights are sharded.
annotation = '{devices=[2,1,2]0,2,1,3 last_tile_dim_replicate}'
self.assertEqual(annotation,
torch_xla._XLAC._get_xla_sharding_spec(model.fc1.weight))
self.assertEqual(annotation,
torch_xla._XLAC._get_xla_sharding_spec(model.fc2.weight))

x = torch.randn(16, 128).to(xm.xla_device())
xs.mark_sharding(x, mesh, (('data', 'fsdp'), None))
output = model(x)
# Make sure output are sharded.
annotation = '{devices=[4,1]0,2,1,3}'
self.assertEqual(annotation,
torch_xla._XLAC._get_xla_sharding_spec(output))

# Make sure the model can execute without error.
xm.mark_step()
xm.wait_device_ops()

@unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
def test_fsdp_v2_multi_slice_output_correctness(self):
model_expected = self.SimpleLinear().to(xm.xla_device())

model = copy.deepcopy(model_expected)
mesh = self._get_mesh((2, self.n_devices // 2, 1), None, ('data', 'fsdp', 'tensor'))
model = FSDPv2(model, mesh=mesh, extra_data_axis="data")

x_expected = torch.randn(16, 128).to(xm.xla_device())

x = copy.deepcopy(x_expected)
xs.mark_sharding(x, mesh, (('data', 'fsdp'), None))

output_expected = model_expected(x_expected)
output = model(x)
self.assertTrue(torch.allclose(output_expected.cpu(), output.cpu()))

def test_fsdp_v2_multi_slice_error(self):
model = self.SimpleLinear().to(xm.xla_device())
xs.set_global_mesh(self._get_mesh((2, self.n_devices // 2, 1), None, ('data', 'fsdp', 'tensor')))

with self.assertRaisesRegex(ValueError, "The provided ddp axis is not in the mesh."):
model = FSDPv2(model, extra_data_axis='ddp')


if __name__ == '__main__':
test = unittest.main()
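
For context, a minimal runnable sketch of the multi-slice usage the new tests above exercise. It assumes an SPMD-enabled TPU environment with an even device count split across two slices; the bare nn.Linear module and the tensor shapes are illustrative, and the FSDPv2 import alias mirrors the test file.

import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
from torch_xla.experimental.spmd_fully_sharded_data_parallel import (
    SpmdFullyShardedDataParallel as FSDPv2)

xr.use_spmd()

# Build a 3-D mesh: a cross-slice 'data' axis, an intra-slice 'fsdp' axis,
# and a trivial 'tensor' axis, mirroring the shape used in the tests above.
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(
    np.arange(num_devices), (2, num_devices // 2, 1),
    ('data', 'fsdp', 'tensor'))

# mesh is now keyword-only; extra_data_axis additionally shards dim 0 of the
# weights (and the output batch) along the 'data' axis.
model = torch.nn.Linear(128, 64).to(xm.xla_device())
model = FSDPv2(model, mesh=mesh, extra_data_axis='data')

x = torch.randn(16, 128).to(xm.xla_device())
xs.mark_sharding(x, mesh, (('data', 'fsdp'), None))  # shard the batch dim
output = model(x)
xm.mark_step()
xm.wait_device_ops()
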
10 changes: 8 additions & 2 deletions torch_xla/experimental/spmd_fully_sharded_data_parallel.py
@@ -13,7 +13,7 @@
 from torch_xla.distributed.fsdp.wrap import recursive_wrap
 
 
-def _prepare_spmd_partition_spec(param):
+def _prepare_spmd_partition_spec(param, extra_data_axis=None):
   partition_spec = [None] * len(param.shape)
   # Skip scalar tensors and it replicated.
   if len(partition_spec) == 0:
@@ -24,6 +24,8 @@ def _prepare_spmd_partition_spec(param):
   # TODO: should we shard on the maximal dim for param? Then we need
   # another helper for the output.
   partition_spec[0] = "fsdp"
+  if extra_data_axis:
+    partition_spec[0] = ("fsdp", extra_data_axis)
   return tuple(partition_spec)
 
 
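A small self-contained sketch of what the modified helper returns. The standalone function name below is hypothetical; it simply restates the logic of the hunk above so it can run without torch_xla.

def prepare_spmd_partition_spec(shape, extra_data_axis=None):
  # Mirrors _prepare_spmd_partition_spec: shard dim 0 along 'fsdp', optionally
  # combined with an extra data axis; scalar tensors stay fully replicated.
  partition_spec = [None] * len(shape)
  if len(partition_spec) == 0:
    return tuple(partition_spec)
  partition_spec[0] = "fsdp"
  if extra_data_axis:
    partition_spec[0] = ("fsdp", extra_data_axis)
  return tuple(partition_spec)


assert prepare_spmd_partition_spec((128, 64)) == ("fsdp", None)
assert prepare_spmd_partition_spec((128, 64), "data") == (("fsdp", "data"), None)
assert prepare_spmd_partition_spec(()) == ()
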
@@ -44,10 +46,12 @@ class SpmdFullyShardedDataParallel(nn.Module):
   def __init__(
       self,
       module: nn.Module,
+      *,
       mesh: Optional[spmd.Mesh] = None,
       shard_output: Optional[Callable] = None,
       auto_wrap_policy: Optional[Callable] = None,
       auto_wrapper_callable: Optional[Callable] = None,
+      extra_data_axis: Optional[str] = None,
   ):
     if isinstance(module, SpmdFullyShardedDataParallel):
       raise RuntimeError(
@@ -74,6 +78,8 @@ def __init__(
       )
     if "fsdp" not in mesh.axis_names:
       raise ValueError("The mesh must have an axis named 'fsdp'.")
+    if extra_data_axis and extra_data_axis not in mesh.axis_names:
+      raise ValueError(f"The provided {extra_data_axis} axis is not in the mesh.")
 
     super().__init__()
 
@@ -131,7 +137,7 @@ def shard_output_impl(output, mesh):
           )
 
       spmd.mark_sharding(real_output, mesh,
-                         _prepare_spmd_partition_spec(real_output))
+                         _prepare_spmd_partition_spec(real_output, extra_data_axis))
 
     shard_output = shard_output_impl
 
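A hypothetical sketch of the two new constructor constraints introduced above: mesh (and everything after it) is now keyword-only because of the bare `*`, and extra_data_axis must name an axis of the mesh. The module and mesh shapes are illustrative, and an SPMD-enabled TPU environment is assumed.

import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
from torch_xla.experimental.spmd_fully_sharded_data_parallel import (
    SpmdFullyShardedDataParallel as FSDPv2)

xr.use_spmd()
n = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(n), (2, n // 2, 1), ('data', 'fsdp', 'tensor'))
module = torch.nn.Linear(128, 64).to(xm.xla_device())

try:
  FSDPv2(module, mesh)  # positional mesh fails: it now follows a bare `*`
except TypeError:
  pass

try:
  FSDPv2(module, mesh=mesh, extra_data_axis='ddp')  # 'ddp' is not a mesh axis
except ValueError as err:
  print(err)  # The provided ddp axis is not in the mesh.

module = FSDPv2(module, mesh=mesh, extra_data_axis='data')  # valid
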
