From d26ab0596d8d1d017c29b93846c42acd652a3dea Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Fri, 10 Jan 2025 11:55:49 -0800 Subject: [PATCH] Introduce ManagedDeviceMesh to integrate DeviceMesh with TorchFT Summary: ManagedDeviceMesh allow users to manipulate DeviceMesh with TorchFT ManagedProcessGroup. ghstack-source-id: c42ae2205624402ccfe99fb87c847f6ffb7a1703 Pull Request resolved: https://github.com/pytorch/torchft/pull/56 --- pyproject.toml | 4 +- torchft/fsdp_test.py | 74 ++++++++++ torchft/process_group.py | 262 ++++++++++++++++++++++++++++++++-- torchft/process_group_test.py | 55 ++++++- 4 files changed, 379 insertions(+), 16 deletions(-) create mode 100644 torchft/fsdp_test.py diff --git a/pyproject.toml b/pyproject.toml index a2da5a2..597ba27 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,9 @@ dev = [ "pytest", "black", "pyre-check", - "parameterized" + "parameterized", + "expecttest", + "numpy" ] [tool.maturin] diff --git a/torchft/fsdp_test.py b/torchft/fsdp_test.py new file mode 100644 index 0000000..ec4dc92 --- /dev/null +++ b/torchft/fsdp_test.py @@ -0,0 +1,74 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import multiprocessing +import os +import unittest +from concurrent.futures import ProcessPoolExecutor +from typing import Any, Dict, Tuple +from unittest.mock import Mock + +import torch +import torch.distributed as dist +from torch import nn +from torch._C._distributed_c10d import ( + AllgatherOptions, + AllreduceOptions, + BroadcastOptions, + ReduceOp, + _resolve_process_group, +) +from torch.distributed import ( + ReduceOp, + TCPStore, + Work, + _functional_collectives, + get_world_size, +) +from torch.distributed._composable.fsdp import fully_shard +from torch.distributed.device_mesh import init_device_mesh + +from torchft.manager import Manager +from torchft.process_group import ManagedProcessGroup, ft_init_device_mesh + + +class FSDPTest(unittest.TestCase): + @staticmethod + def _test_fsdp(world_size: int, rank: int) -> None: + torch.cuda.set_device(rank) + + group_size = world_size // 2 + group = rank // group_size + group_rank = rank % group_size + + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = str(12346 + group) + os.environ["RANK"] = str(group_rank) + os.environ["WORLD_SIZE"] = str(group_size) + + manager = Mock(spec=Manager) + device_mesh = ft_init_device_mesh( + device_type="cuda", + mesh_shape=(2, 2), + mesh_dim_names=("dp_replicate", "dp_shard"), + replicate_dim=0, + manager=manager, + ) + manager.num_participants.return_value = 1 + model = nn.Linear(128, 128).cuda() + batch = torch.randn(4, 128).cuda() + shard_model = fully_shard(model, mesh=device_mesh) + shard_model(batch).mean().backward() + + # pyre-ignore[56]: Pyre was not able to infer the type of argument + @unittest.skipIf(torch.cuda.device_count() < 4, "Not enough GPUs") + def test_fsdp(self) -> None: + multiprocessing.set_start_method("spawn") + with ProcessPoolExecutor(max_workers=4) as executor: + futures = [] + for i in range(4): + future = executor.submit(self._test_fsdp, 4, i) + futures.append(future) diff --git a/torchft/process_group.py b/torchft/process_group.py index 43a7716..34797c7 100644 --- a/torchft/process_group.py +++ b/torchft/process_group.py @@ -21,7 +21,7 @@ import threading from abc import ABC from datetime import timedelta -from typing import TYPE_CHECKING, Dict, List, Optional, Type, Union +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type, Union import torch import torch.distributed as dist @@ -39,6 +39,7 @@ Store, TCPStore, get_rank, + init_device_mesh, ) from torch.distributed.distributed_c10d import Work, _world from torch.futures import Future @@ -149,17 +150,7 @@ def size(self) -> int: def getBackendName(self) -> str: raise NotImplementedError("not implemented") - def register(self, name: str) -> "ProcessGroup": - """ - Registers the process group with the global registry. This enables usage - with things like functional_collectives which are compilable. - - This should only be called once. - - Args: - name: name must be a unique name for this process group - """ - + def _register(self, name: str) -> str: group_name = f"{self.getBackendName()}:{name}" # This is needed for DeviceMesh and functional collectives to work. @@ -177,6 +168,21 @@ def create_pg( devices = ["cpu"] dist.Backend.register_backend(group_name, create_pg, devices=devices) + return group_name + + def register(self, name: str) -> "ProcessGroup": + """ + Registers the process group with the global registry. This enables usage + with things like functional_collectives which are compilable. + + This should only be called once. + + Args: + name: name must be a unique name for this process group + """ + + group_name = self._register(name) + return dist.new_group( ranks=[dist.get_rank()], backend=group_name, @@ -519,6 +525,9 @@ def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work: def size(self) -> int: return self._manager.num_participants() + def getBackendName(self) -> str: + return self._manager._pg.getBackendName() + class _BabyWork(Work): def __init__( @@ -730,7 +739,6 @@ def _future_handler(self, future_queue: mp.Queue) -> None: logger.exception(f"got unexpected error in future handler: {e}") def _get_future(self, op_id: int) -> Future[object]: - with self._futures_lock: fut = Future() # pyre-fixme[29]: is not a function self._futures[op_id] = fut @@ -841,3 +849,231 @@ def extend_device_mesh( mesh=mesh.mesh.unsqueeze(dim), mesh_dim_names=tuple(mesh_dim_names), ) + + +class ManagedDeviceMesh(DeviceMesh): + def __init__( + self, + mesh: Optional[DeviceMesh], + mesh_dim_names: Tuple[str, ...], + replicate_pg: ManagedProcessGroup, + replicate_dim: int, + parent: Optional["ManagedDeviceMesh"], + ) -> None: + if mesh is None and parent is None: + raise ValueError( + "ManagedDeviceMesh doesn't support both mesh and parent are None." + ) + self.mesh = mesh + self.mesh_dim_names = mesh_dim_names + self.replicate_pg = replicate_pg + self.replicate_dim = replicate_dim + self.replicate_dim_name: str = mesh_dim_names[replicate_dim] + self.parent = parent + self.flatten_meshes: Dict[str, DeviceMesh] = {} + self.device_type: str + if mesh is not None: + self.device_type = mesh.device_type + else: + assert parent is not None + self.device_type = parent.device_type + self._flatten_mesh_list: Tuple[DeviceMesh, ...] = tuple() + self._thread_id: Optional[int] = None + + def __getitem__(self, mesh_dim_names: Union[str, Tuple[str, ...]]) -> DeviceMesh: + if isinstance(mesh_dim_names, str): + if mesh_dim_names == self.replicate_dim_name: + return ManagedDeviceMesh( + mesh=None, + mesh_dim_names=(mesh_dim_names,), + replicate_pg=self.replicate_pg, + replicate_dim=0, + parent=self, + ) + elif mesh_dim_names in self.flatten_meshes: + return self.flatten_meshes[mesh_dim_names] + else: + assert self.mesh is not None + return self.mesh[mesh_dim_names] + else: + assert isinstance(mesh_dim_names, tuple) + if self.replicate_dim_name in mesh_dim_names: + assert self.mesh is not None + return self.mesh[mesh_dim_names] + else: + assert self.mesh is not None + return ManagedDeviceMesh( + self.mesh[mesh_dim_names], + mesh_dim_names, + self.replicate_pg, + mesh_dim_names.index(self.replicate_dim_name), + parent=self, + ) + + def _real_mesh_dim(self, mesh_dim: int) -> int: + return mesh_dim - 1 if mesh_dim > self.replicate_dim else mesh_dim + + def get_group(self, mesh_dim: Optional[Union[int, str]] = None) -> BaseProcessGroup: + if isinstance(mesh_dim, str): + dim = self.mesh_dim_names.index(mesh_dim) + else: + dim = 0 if mesh_dim is None else int(mesh_dim) + + if mesh_dim is None: + return self.replicate_pg + elif dim == self.replicate_dim: + return self.replicate_pg + else: + assert self.mesh is not None + return self.mesh.get_group(self._real_mesh_dim(dim)) + + def _flatten(self, mesh_dim_name: Optional[str]) -> "DeviceMesh": + flatten_mesh = _FlattenDeviceMesh(self) + if mesh_dim_name is None: + raise ValueError("ManagedDeviceMesh._flatten requires `mesh_dim_name`") + if self.parent is None: + self.flatten_meshes[mesh_dim_name] = flatten_mesh + else: + self.parent.flatten_meshes[mesh_dim_name] = flatten_mesh + return flatten_mesh + + def size(self, mesh_dim: Optional[int] = None) -> int: + if mesh_dim is None: + if self.mesh is None: + return self.replicate_pg.size() + else: + assert self.mesh is not None + return self.mesh.size() * self.replicate_pg.size() + elif mesh_dim == self.replicate_dim: + return self.replicate_pg.size() + else: + assert self.mesh is not None + return self.mesh.size(self._real_mesh_dim(mesh_dim)) + + @property + def ndim(self) -> int: + assert self.mesh is not None + return self.mesh.ndim + 1 + + @property + def shape(self) -> Tuple[int, ...]: + assert self.mesh is not None + ret: List[int] = list(self.mesh.shape) + ret.insert(self.replicate_dim, self.replicate_pg.size()) + return tuple(ret) + + def get_rank(self) -> int: + assert self.mesh is not None + return self.mesh.get_rank() + + def get_local_rank(self, mesh_dim: Optional[Union[int, str]] = None) -> int: + if isinstance(mesh_dim, str): + dim = self.mesh_dim_names.index(mesh_dim) + else: + dim = 0 if mesh_dim is None else int(mesh_dim) + + if mesh_dim is None: + if self.mesh is None: + return get_rank(self.replicate_pg) + + assert self.replicate_dim == 0, "replicate_dim must be the first one" + assert self.mesh is not None + other_dim_size = self.mesh.size() + assert self.mesh is not None + other_dim_rank = self.mesh.get_local_rank() + replicate_pg_rank = get_rank(self.replicate_pg) + return other_dim_size * replicate_pg_rank + other_dim_rank + elif dim == self.replicate_dim: + return get_rank(self.replicate_pg) + else: + assert self.mesh is not None + return self.mesh.get_local_rank(self._real_mesh_dim(dim)) + + def get_coordinate(self) -> Optional[List[int]]: + """ + Return the relative indices of this rank relative to all + dimensions of the mesh. If this rank is not part of the mesh, return None. + """ + assert self.mesh is not None + return self.mesh._coordinate_on_dim if self.mesh._coordinate_on_dim else None + + def get_all_groups(self) -> List[BaseProcessGroup]: + raise NotImplementedError + + +class _FlattenDeviceMesh(DeviceMesh): + def __init__(self, managed_mesh: ManagedDeviceMesh) -> None: + self.managed_mesh = managed_mesh + + def __getitem__(self, mesh_dim_names: Union[str, Tuple[str, ...]]) -> DeviceMesh: + raise NotImplementedError + + def get_group(self, mesh_dim: Optional[Union[int, str]] = None) -> BaseProcessGroup: + raise NotImplementedError + + def _flatten(self, mesh_dim_name: Optional[str]) -> "DeviceMesh": + raise NotImplementedError + + def size(self, mesh_dim: Optional[int] = None) -> int: + assert mesh_dim is None + return self.managed_mesh.size() + + @property + def ndim(self) -> int: + raise NotImplementedError + + @property + def shape(self) -> Tuple[int, ...]: + raise NotImplementedError + + def get_rank(self) -> int: + raise NotImplementedError + + def get_local_rank(self, mesh_dim: Optional[Union[int, str]] = None) -> int: + assert mesh_dim is None + return self.managed_mesh.get_local_rank() + + def get_all_groups(self) -> List[BaseProcessGroup]: + raise NotImplementedError + + +def ft_init_device_mesh( + *, + device_type: str, + mesh_shape: Tuple[int, ...], + mesh_dim_names: Tuple[str, ...], + replicate_dim: int, + manager: "Manager", +) -> "ManagedDeviceMesh": + # We need to mislead DeviceMesh into thinking that replicate_dim has only + # 1 rank. + _mesh_shape = list(mesh_shape) + _mesh_shape.pop(replicate_dim) + _mesh_dim_names = list(mesh_dim_names) + _mesh_dim_names.pop(replicate_dim) + mesh = init_device_mesh( + device_type, + mesh_shape=tuple(_mesh_shape), + mesh_dim_names=tuple(_mesh_dim_names), + ) + + if device_type == "cpu": + pg = ProcessGroupGloo() + elif device_type == "cuda": + pg = ProcessGroupNCCL() + else: + raise ValueError() + + manager._pg = pg + replicate_pg = ManagedProcessGroup(manager) + # We have to use MultiProcessTestCase, otherwise c10d will complain + # the same backend has been registered. + replicate_pg.register(mesh_dim_names[replicate_dim]) + + return ManagedDeviceMesh( + mesh=mesh, + mesh_dim_names=mesh_dim_names, + replicate_pg=replicate_pg, + replicate_dim=replicate_dim, + parent=None, + ) diff --git a/torchft/process_group_test.py b/torchft/process_group_test.py index 8d618d1..6a849e7 100644 --- a/torchft/process_group_test.py +++ b/torchft/process_group_test.py @@ -4,10 +4,12 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import multiprocessing import os -from concurrent.futures import ThreadPoolExecutor +import unittest +from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor from datetime import timedelta -from typing import Any, Dict, Tuple +from typing import Any, Dict, Tuple, cast from unittest import TestCase, skipUnless from unittest.mock import Mock @@ -45,6 +47,7 @@ _ErrorSwallowingWork, _ManagedWork, extend_device_mesh, + ft_init_device_mesh, ) @@ -261,6 +264,7 @@ def test_device_mesh(self) -> None: pg.configure(store_addr, 0, 1) mesh_2d = extend_device_mesh(mesh_1d, pg) + mesh_2d.get_group("dp") assert mesh_2d.ndim == 2 pg.unregister() @@ -326,3 +330,50 @@ def test_managed_process_group(self) -> None: self.assertEqual(manager.report_error.call_count, 0) self.assertEqual(manager.wrap_future.call_count, 1) + + +class DeviceMeshTest(TestCase): + @staticmethod + def _test_init_device_mesh(world_size: int, rank: int) -> None: + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = str(12346) + os.environ["RANK"] = str(rank) + os.environ["WORLD_SIZE"] = str(4) + + testcase = TestCase() + + manager = Mock(spec=Manager) + # Even though we only have 4 workers, we can still initialize (2, 4) mesh. + # That's because the replicate group is NOT phystically created in the + # real mesh but is virtually added to the mesh via ManagedDeviceMesh. + device_mesh = ft_init_device_mesh( + device_type="cpu", + mesh_shape=(2, world_size), + mesh_dim_names=("dp_replicate", "dp_shard"), + replicate_dim=0, + manager=manager, + ) + + testcase.assertTrue( + isinstance(device_mesh.get_group("dp_replicate"), ManagedProcessGroup) + ) + testcase.assertTrue( + not isinstance(device_mesh.get_group("dp_shard"), ManagedProcessGroup) + ) + replicate_group = device_mesh.get_group("dp_replicate") + testcase.assertEqual( + cast(ManagedProcessGroup, replicate_group)._manager, manager + ) + replicate_mesh = device_mesh["dp_replicate"] + testcase.assertEqual(replicate_mesh.get_group(), replicate_group) + flatten_mesh = device_mesh._flatten("dp") + manager.num_participants.return_value = 1 + testcase.assertEqual(flatten_mesh.size(), world_size) + testcase.assertEqual(flatten_mesh.get_local_rank(), dist.get_rank()) + + def test_init_device_mesh(self) -> None: + with ProcessPoolExecutor(max_workers=4) as executor: + futures = [] + for i in range(4): + future = executor.submit(self._test_init_device_mesh, 4, i) + futures.append(future)