From bc08aac77ffa914d4824c14c9529ef33e122f3f2 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Thu, 9 Jan 2025 12:00:16 -0800 Subject: [PATCH] Introduce ManagedDeviceMesh to integrate DeviceMesh with TorchFT Summary: ManagedDeviceMesh allow users to manipulate DeviceMesh with TorchFT ManagedProcessGroup. ghstack-source-id: ace0838d729c7ecdd3720fb9037185c83d9a289a Pull Request resolved: https://github.com/pytorch-labs/torchft/pull/56 --- pyproject.toml | 4 +- torchft/fsdp_test.py | 70 +++++++++ torchft/process_group.py | 263 ++++++++++++++++++++++++++++++++-- torchft/process_group_test.py | 47 ++++++ 4 files changed, 370 insertions(+), 14 deletions(-) create mode 100644 torchft/fsdp_test.py diff --git a/pyproject.toml b/pyproject.toml index a2da5a2..597ba27 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,9 @@ dev = [ "pytest", "black", "pyre-check", - "parameterized" + "parameterized", + "expecttest", + "numpy" ] [tool.maturin] diff --git a/torchft/fsdp_test.py b/torchft/fsdp_test.py new file mode 100644 index 0000000..f4337f2 --- /dev/null +++ b/torchft/fsdp_test.py @@ -0,0 +1,70 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import os +from concurrent.futures import ThreadPoolExecutor +from typing import Any, Dict, Tuple +from unittest import TestCase, skipUnless +from unittest.mock import Mock + +import torch +import torch.distributed as dist +from torch import nn +from torch._C._distributed_c10d import ( + AllgatherOptions, + AllreduceOptions, + BroadcastOptions, + ReduceOp, + _resolve_process_group, +) +from torch.distributed import ( + ReduceOp, + TCPStore, + Work, + _functional_collectives, + get_world_size, +) +from torch.distributed._composable.fsdp import fully_shard +from torch.distributed.device_mesh import init_device_mesh +from torch.testing._internal.common_distributed import MultiProcessTestCase + +from torchft.manager import Manager +from torchft.process_group import ManagedProcessGroup, ft_init_device_mesh + + +class FSDPTest(MultiProcessTestCase): + @property + def world_size(self) -> int: + return 4 + + def setUp(self) -> None: + super().setUp() + os.environ["TORCH_NCCL_DESYNC_DEBUG"] = "0" + self._spawn_processes() + + def test_fsdp(self) -> None: + group_size = self.world_size // 2 + group = self.rank // group_size + group_rank = self.rank % group_size + + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = str(12346 + group) + os.environ["RANK"] = str(group_rank) + os.environ["WORLD_SIZE"] = str(group_size) + + manager = Mock(spec=Manager) + device_mesh = ft_init_device_mesh( + device_type="cuda", + mesh_shape=(2, 2), + mesh_dim_names=("dp_replicate", "dp_shard"), + replicate_dim=0, + manager=manager, + ) + manager.num_participants.return_value = 1 + model = nn.Linear(128, 128).cuda() + batch = torch.randn(4, 128).cuda() + shard_model = fully_shard(model, mesh=device_mesh) + shard_model(batch).mean().backward() diff --git a/torchft/process_group.py b/torchft/process_group.py index 735aa4e..693b23b 100644 --- a/torchft/process_group.py +++ b/torchft/process_group.py @@ -20,7 +20,7 @@ import threading from abc import ABC from datetime import timedelta -from typing import TYPE_CHECKING, Dict, List, Optional, Type +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type, Union import torch import torch.distributed as dist @@ -38,6 +38,7 @@ Store, TCPStore, get_rank, + init_device_mesh, ) from torch.distributed.distributed_c10d import Work, _world from torch.futures import Future @@ -130,17 +131,7 @@ def size(self) -> int: def getBackendName(self) -> str: raise NotImplementedError("not implemented") - def register(self, name: str) -> "ProcessGroup": - """ - Registers the process group with the global registry. This enables usage - with things like functional_collectives which are compilable. - - This should only be called once. - - Args: - name: name must be a unique name for this process group - """ - + def _register(self, name: str) -> str: group_name = f"{self.getBackendName()}:{name}" # This is needed for DeviceMesh and functional collectives to work. @@ -158,6 +149,21 @@ def create_pg( devices = ["cpu"] dist.Backend.register_backend(group_name, create_pg, devices=devices) + return group_name + + def register(self, name: str) -> "ProcessGroup": + """ + Registers the process group with the global registry. This enables usage + with things like functional_collectives which are compilable. + + This should only be called once. + + Args: + name: name must be a unique name for this process group + """ + + group_name = self._register(name) + return dist.new_group( ranks=[dist.get_rank()], backend=group_name, @@ -496,6 +502,9 @@ def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work: def size(self) -> int: return self._manager.num_participants() + def getBackendName(self) -> str: + return self._manager._pg.getBackendName() + class _BabyWork(Work): def __init__( @@ -689,7 +698,6 @@ def _future_handler(self, future_queue: mp.Queue) -> None: logger.exception(f"got unexpected error in future handler: {e}") def _get_future(self, op_id: int) -> Future[object]: - with self._futures_lock: fut = Future() # pyre-fixme[29]: is not a function self._futures[op_id] = fut @@ -797,3 +805,232 @@ def extend_device_mesh( mesh=mesh.mesh.unsqueeze(dim), mesh_dim_names=tuple(mesh_dim_names), ) + + +class ManagedDeviceMesh(DeviceMesh): + def __init__( + self, + mesh: Optional[DeviceMesh], + mesh_dim_names: Tuple[str, ...], + replicate_pg: ManagedProcessGroup, + replicate_dim: int, + parent: Optional["ManagedDeviceMesh"], + ) -> None: + if mesh is None and parent is not None: + raise ValueError( + "ManagedDeviceMesh doesn't support both mesh and parent are None." + ) + self.mesh = mesh + self.mesh_dim_names = mesh_dim_names + self.replicate_pg = replicate_pg + self.replicate_dim = replicate_dim + self.replicate_dim_name: str = mesh_dim_names[replicate_dim] + self.parent = parent + self.flatten_meshes: Dict[str, DeviceMesh] = {} + self.device_type: str + if mesh is not None: + self.device_type = mesh.device_type + else: + assert parent is not None + self.device_type = parent.device_type + self._flatten_mesh_list: Tuple[DeviceMesh, ...] = tuple() + self._thread_id: Optional[int] = None + + def __getitem__(self, mesh_dim_names: Union[str, Tuple[str, ...]]) -> DeviceMesh: + if isinstance(mesh_dim_names, str): + if mesh_dim_names == self.replicate_dim_name: + return ManagedDeviceMesh( + mesh=None, + mesh_dim_names=(mesh_dim_names,), + replicate_pg=self.replicate_pg, + replicate_dim=0, + parent=self, + ) + elif mesh_dim_names in self.flatten_meshes: + return self.flatten_meshes[mesh_dim_names] + else: + assert self.mesh is not None + return self.mesh[mesh_dim_names] + else: + assert isinstance(mesh_dim_names, tuple) + if self.replicate_dim_name in mesh_dim_names: + assert self.mesh is not None + return self.mesh[mesh_dim_names] + else: + assert self.mesh is not None + return ManagedDeviceMesh( + self.mesh[mesh_dim_names], + mesh_dim_names, + self.replicate_pg, + mesh_dim_names.index(self.replicate_dim_name), + parent=self, + ) + + def _real_mesh_dim(self, mesh_dim: int) -> int: + return mesh_dim - 1 if mesh_dim > self.replicate_dim else mesh_dim + + def get_group(self, mesh_dim: Optional[Union[int, str]] = None) -> BaseProcessGroup: + if isinstance(mesh_dim, str): + dim = self.mesh_dim_names.index(mesh_dim) + else: + dim = 0 if mesh_dim is None else int(mesh_dim) + + if mesh_dim is None: + assert self.mesh is not None + return self.replicate_pg + elif dim == self.replicate_dim: + return self.replicate_pg + else: + assert self.mesh is not None + return self.mesh.get_group(self._real_mesh_dim(dim)) + + def _flatten(self, mesh_dim_name: Optional[str]) -> "DeviceMesh": + flatten_mesh = _FlattenDeviceMesh(self) + if mesh_dim_name is None: + raise ValueError("ManagedDeviceMesh._flatten requires `mesh_dim_name`") + if self.parent is None: + self.flatten_meshes[mesh_dim_name] = flatten_mesh + else: + self.parent.flatten_meshes[mesh_dim_name] = flatten_mesh + return flatten_mesh + + def size(self, mesh_dim: Optional[int] = None) -> int: + if mesh_dim is None: + if self.mesh is None: + return self.replicate_pg.size() + else: + assert self.mesh is not None + return self.mesh.size() * self.replicate_pg.size() + elif mesh_dim == self.replicate_dim: + return self.replicate_pg.size() + else: + assert self.mesh is not None + return self.mesh.size(self._real_mesh_dim(mesh_dim)) + + @property + def ndim(self) -> int: + assert self.mesh is not None + return self.mesh.ndim + 1 + + @property + def shape(self) -> Tuple[int, ...]: + assert self.mesh is not None + ret: List[int] = list(self.mesh.shape) + ret.insert(self.replicate_dim, self.replicate_pg.size()) + return tuple(ret) + + def get_rank(self) -> int: + assert self.mesh is not None + return self.mesh.get_rank() + + def get_local_rank(self, mesh_dim: Optional[Union[int, str]] = None) -> int: + if isinstance(mesh_dim, str): + dim = self.mesh_dim_names.index(mesh_dim) + else: + dim = 0 if mesh_dim is None else int(mesh_dim) + + if mesh_dim is None: + if self.mesh is None: + return get_rank(self.replicate_pg) + + assert self.replicate_dim == 0, "replicate_dim must be the first one" + assert self.mesh is not None + other_dim_size = self.mesh.size() + assert self.mesh is not None + other_dim_rank = self.mesh.get_local_rank() + replicate_pg_rank = get_rank(self.replicate_pg) + return other_dim_size * replicate_pg_rank + other_dim_rank + elif dim == self.replicate_dim: + return get_rank(self.replicate_pg) + else: + assert self.mesh is not None + return self.mesh.get_local_rank(self._real_mesh_dim(dim)) + + def get_coordinate(self) -> Optional[List[int]]: + """ + Return the relative indices of this rank relative to all + dimensions of the mesh. If this rank is not part of the mesh, return None. + """ + assert self.mesh is not None + return self.mesh._coordinate_on_dim if self.mesh._coordinate_on_dim else None + + def get_all_groups(self) -> List[BaseProcessGroup]: + raise NotImplementedError + + +class _FlattenDeviceMesh(DeviceMesh): + def __init__(self, managed_mesh: ManagedDeviceMesh) -> None: + self.managed_mesh = managed_mesh + + def __getitem__(self, mesh_dim_names: Union[str, Tuple[str, ...]]) -> DeviceMesh: + raise NotImplementedError + + def get_group(self, mesh_dim: Optional[Union[int, str]] = None) -> BaseProcessGroup: + raise NotImplementedError + + def _flatten(self, mesh_dim_name: Optional[str]) -> "DeviceMesh": + raise NotImplementedError + + def size(self, mesh_dim: Optional[int] = None) -> int: + assert mesh_dim is None + return self.managed_mesh.size() + + @property + def ndim(self) -> int: + raise NotImplementedError + + @property + def shape(self) -> Tuple[int, ...]: + raise NotImplementedError + + def get_rank(self) -> int: + raise NotImplementedError + + def get_local_rank(self, mesh_dim: Optional[Union[int, str]] = None) -> int: + assert mesh_dim is None + return self.managed_mesh.get_local_rank() + + def get_all_groups(self) -> List[BaseProcessGroup]: + raise NotImplementedError + + +def ft_init_device_mesh( + *, + device_type: str, + mesh_shape: Tuple[int, ...], + mesh_dim_names: Tuple[str, ...], + replicate_dim: int, + manager: "Manager", +) -> "ManagedDeviceMesh": + # We need to mislead DeviceMesh into thinking that replicate_dim has only + # 1 rank. + _mesh_shape = list(mesh_shape) + _mesh_shape.pop(replicate_dim) + _mesh_dim_names = list(mesh_dim_names) + _mesh_dim_names.pop(replicate_dim) + mesh = init_device_mesh( + device_type, + mesh_shape=tuple(_mesh_shape), + mesh_dim_names=tuple(_mesh_dim_names), + ) + + if device_type == "cpu": + pg = ProcessGroupGloo() + elif device_type == "cuda": + pg = ProcessGroupNCCL() + else: + raise ValueError() + + manager._pg = pg + replicate_pg = ManagedProcessGroup(manager) + # We have to use MultiProcessTestCase, otherwise c10d will complain + # the same backend has been registered. + replicate_pg.register(mesh_dim_names[replicate_dim]) + + return ManagedDeviceMesh( + mesh=mesh, + mesh_dim_names=mesh_dim_names, + replicate_pg=replicate_pg, + replicate_dim=replicate_dim, + parent=None, + ) diff --git a/torchft/process_group_test.py b/torchft/process_group_test.py index 44e770d..54523b1 100644 --- a/torchft/process_group_test.py +++ b/torchft/process_group_test.py @@ -28,6 +28,7 @@ get_world_size, ) from torch.distributed.device_mesh import init_device_mesh +from torch.testing._internal.common_distributed import MultiProcessTestCase from torchft.manager import Manager from torchft.process_group import ( @@ -44,6 +45,7 @@ _ErrorSwallowingWork, _ManagedWork, extend_device_mesh, + ft_init_device_mesh, ) @@ -234,6 +236,7 @@ def test_device_mesh(self) -> None: pg.configure(store_addr, 0, 1) mesh_2d = extend_device_mesh(mesh_1d, pg) + mesh_2d.get_group("dp") assert mesh_2d.ndim == 2 pg.unregister() @@ -299,3 +302,47 @@ def test_managed_process_group(self) -> None: self.assertEqual(manager.report_error.call_count, 0) self.assertEqual(manager.wrap_future.call_count, 1) + + +class DeviceMeshTest(MultiProcessTestCase): + @property + def world_size(self) -> int: + return 4 + + def setUp(self) -> None: + super().setUp() + os.environ["TORCH_NCCL_DESYNC_DEBUG"] = "0" + self._spawn_processes() + + def test_init_device_mesh(self) -> None: + os.environ["MASTER_PORT"] = str(12346) + os.environ["RANK"] = str(self.rank) + os.environ["WORLD_SIZE"] = str(4) + + manager = Mock(spec=Manager) + # Even though we only have 4 workers, we can still initialize (2, 4) mesh. + # That's because the replicate group is NOT phystically created in the + # real mesh but is virtually added to the mesh via ManagedDeviceMesh. + device_mesh = ft_init_device_mesh( + device_type="cpu", + mesh_shape=(2, self.world_size), + mesh_dim_names=("dp_replicate", "dp_shard"), + replicate_dim=0, + manager=manager, + ) + + self.assertTrue( + isinstance(device_mesh.get_group("dp_replicate"), ManagedProcessGroup) + ) + self.assertTrue( + not isinstance(device_mesh.get_group("dp_shard"), ManagedProcessGroup) + ) + replicate_group = device_mesh.get_group("dp_replicate") + # pyre-ignore[16] + self.assertEqual(replicate_group._manager, manager) + replicate_mesh = device_mesh["dp_replicate"] + self.assertEqual(replicate_mesh.get_group(), replicate_group) + flatten_mesh = device_mesh._flatten("dp") + manager.num_participants.return_value = 1 + self.assertEqual(flatten_mesh.size(), self.world_size) + self.assertEqual(flatten_mesh.get_local_rank(), dist.get_rank())