Introduce ManagedDeviceMesh to integrate DeviceMesh with TorchFT
Summary:
ManagedDeviceMesh allows users to manipulate a DeviceMesh with the TorchFT ManagedProcessGroup.
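
For context, a minimal usage sketch (mirroring the test added in this commit; the mocked Manager stands in for a real torchft Manager):

    from unittest.mock import Mock

    from torchft.manager import Manager
    from torchft.process_group import ft_init_device_mesh

    manager = Mock(spec=Manager)  # a live Manager would be used in practice
    device_mesh = ft_init_device_mesh(
        device_type="cuda",
        mesh_shape=(2, 2),
        mesh_dim_names=("dp_replicate", "dp_shard"),
        replicate_dim=0,
        manager=manager,
    )
    manager.num_participants.return_value = 1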

ghstack-source-id: 321d2f2f5ff2cf9bc16622623b2d80eb95db33cf
Pull Request resolved: #56
fegin committed Jan 8, 2025
1 parent f31d3b1 commit 3e3a5ce
Showing 3 changed files with 366 additions and 13 deletions.
71 changes: 71 additions & 0 deletions torchft/fsdp_test.py
@@ -0,0 +1,71 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import os
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Dict, Tuple
from unittest import TestCase, skipUnless
from unittest.mock import Mock

import torch
import torch.distributed as dist
from torch import nn
from torch._C._distributed_c10d import (
    AllgatherOptions,
    AllreduceOptions,
    BroadcastOptions,
    ReduceOp,
    _resolve_process_group,
)
from torch.distributed import (
    TCPStore,
    Work,
    _functional_collectives,
    get_world_size,
)
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed.device_mesh import init_device_mesh
from torch.testing._internal.common_distributed import MultiProcessTestCase

from torchft.manager import Manager
from torchft.process_group import ManagedProcessGroup, ft_init_device_mesh


class FSDPTest(MultiProcessTestCase):
    @property
    def world_size(self) -> int:
        return 4

    def setUp(self) -> None:
        super().setUp()
        os.environ["TORCH_NCCL_DESYNC_DEBUG"] = "0"
        self._spawn_processes()

    def test_fsdp(self) -> None:
        group_size = self.world_size // 2
        # pyre-ignore[16]
        group = self.rank // group_size
        group_rank = self.rank % group_size

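        # Each replica group rendezvouses on its own MASTER_PORT below, so the
        # two groups of two ranks each behave as independent worlds.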
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = str(12346 + group)
os.environ["RANK"] = str(group_rank)
os.environ["WORLD_SIZE"] = str(group_size)

manager = Mock(spec=Manager)
device_mesh = ft_init_device_mesh(
device_type="cuda",
mesh_shape=(2, 2),
mesh_dim_names=("dp_replicate", "dp_shard"),
replicate_dim=0,
manager=manager,
)
manager.num_participants.return_value = 1
model = nn.Linear(128, 128).cuda()
batch = torch.randn(4, 128).cuda()
shard_model = fully_shard(model, mesh=device_mesh)
shard_model(batch).mean().backward()
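
For reference, the arithmetic above splits the four ranks into two independent replica groups; a standalone illustration (not part of the commit):

    world_size = 4
    group_size = world_size // 2  # 2 ranks per replica group
    for rank in range(world_size):
        group = rank // group_size      # ranks {0,1} -> group 0, {2,3} -> group 1
        group_rank = rank % group_size  # rank within the replica group
        print(f"rank={rank} group={group} group_rank={group_rank}")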
259 changes: 246 additions & 13 deletions torchft/process_group.py
@@ -20,7 +20,7 @@
import threading
from abc import ABC
from datetime import timedelta
from typing import TYPE_CHECKING, Dict, List, Optional, Type
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type, Union

import torch
import torch.distributed as dist
@@ -38,6 +38,7 @@
    Store,
    TCPStore,
    get_rank,
    init_device_mesh,
)
from torch.distributed.distributed_c10d import Work, _world
from torch.futures import Future
@@ -130,17 +131,7 @@ def size(self) -> int:
    def getBackendName(self) -> str:
        raise NotImplementedError("not implemented")

    def register(self, name: str) -> "ProcessGroup":
        """
        Registers the process group with the global registry. This enables usage
        with things like functional_collectives which are compilable.

        This should only be called once.

        Args:
            name: name must be a unique name for this process group
        """

    def _register(self, name: str) -> str:
        group_name = f"{self.getBackendName()}:{name}"

        # This is needed for DeviceMesh and functional collectives to work.
@@ -158,6 +149,21 @@ def create_pg(
        devices = ["cpu"]
        dist.Backend.register_backend(group_name, create_pg, devices=devices)

        return group_name

    def register(self, name: str) -> "ProcessGroup":
        """
        Registers the process group with the global registry. This enables usage
        with things like functional_collectives which are compilable.

        This should only be called once.

        Args:
            name: name must be a unique name for this process group
        """

        group_name = self._register(name)

        return dist.new_group(
            ranks=[dist.get_rank()],
            backend=group_name,
@@ -496,6 +502,9 @@ def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
    def size(self) -> int:
        return self._manager.num_participants()

    def getBackendName(self) -> str:
        return self._manager._pg.getBackendName()


class _BabyWork(Work):
    def __init__(
@@ -689,7 +698,6 @@ def _future_handler(self, future_queue: mp.Queue) -> None:
            logger.exception(f"got unexpected error in future handler: {e}")

    def _get_future(self, op_id: int) -> Future[object]:
        with self._futures_lock:
            fut = Future()  # pyre-fixme[29]: is not a function
            self._futures[op_id] = fut
@@ -797,3 +805,228 @@ def extend_device_mesh(
        mesh=mesh.mesh.unsqueeze(dim),
        mesh_dim_names=tuple(mesh_dim_names),
    )


class ManagedDeviceMesh(DeviceMesh):
    def __init__(
        self,
        mesh: Optional[DeviceMesh],
        mesh_dim_names: Tuple[str, ...],
        replicate_pg: ManagedProcessGroup,
        replicate_dim: int,
        parent: Optional["ManagedDeviceMesh"],
    ) -> None:
        if mesh is None and parent is None:
            raise ValueError(
                "ManagedDeviceMesh doesn't support both mesh and parent being None."
            )
        self.mesh = mesh
        self.mesh_dim_names = mesh_dim_names
        self.replicate_pg = replicate_pg
        self.replicate_dim = replicate_dim
        self.replicate_dim_name: str = mesh_dim_names[replicate_dim]
        self.parent = parent
        self.flatten_meshes: Dict[str, DeviceMesh] = {}
        self.device_type: str
        if mesh is not None:
            self.device_type = mesh.device_type
        else:
            assert parent is not None
            self.device_type = parent.device_type
        self._flatten_mesh_list: Tuple[DeviceMesh, ...] = tuple()
        self._thread_id: Optional[int] = None

    def __getitem__(self, mesh_dim_names: Union[str, Tuple[str, ...]]) -> DeviceMesh:
        if isinstance(mesh_dim_names, str):
            if mesh_dim_names == self.replicate_dim_name:
                return ManagedDeviceMesh(
                    mesh=None,
                    mesh_dim_names=(mesh_dim_names,),
                    replicate_pg=self.replicate_pg,
                    replicate_dim=0,
                    parent=self,
                )
            elif mesh_dim_names in self.flatten_meshes:
                return self.flatten_meshes[mesh_dim_names]
            else:
                assert self.mesh is not None
                return self.mesh[mesh_dim_names]
        else:
            assert isinstance(mesh_dim_names, tuple)
            if self.replicate_dim_name not in mesh_dim_names:
                assert self.mesh is not None
                return self.mesh[mesh_dim_names]
            else:
                assert self.mesh is not None
                # The inner mesh doesn't contain the replicate dim, so slice it
                # with the remaining names and rewrap.
                inner_dim_names = tuple(
                    n for n in mesh_dim_names if n != self.replicate_dim_name
                )
                return ManagedDeviceMesh(
                    self.mesh[inner_dim_names],
                    mesh_dim_names,
                    self.replicate_pg,
                    mesh_dim_names.index(self.replicate_dim_name),
                    parent=self,
                )

    def _real_mesh_dim(self, mesh_dim: int) -> int:
        return mesh_dim - 1 if mesh_dim > self.replicate_dim else mesh_dim
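
    # Example (illustration, not in the original commit): with
    # mesh_dim_names=("dp_replicate", "dp_shard") and replicate_dim=0, managed
    # dim 1 ("dp_shard") maps to dim 0 of the inner mesh, because the inner
    # mesh was built without the replicate dim.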

    def get_group(self, mesh_dim: Optional[Union[int, str]] = None) -> BaseProcessGroup:
        if isinstance(mesh_dim, str):
            dim = self.mesh_dim_names.index(mesh_dim)
        else:
            dim = 0 if mesh_dim is None else int(mesh_dim)

        if mesh_dim is None:
            assert self.mesh is not None
            return self.replicate_pg
        elif dim == self.replicate_dim:
            return self.replicate_pg
        else:
            assert self.mesh is not None
            return self.mesh.get_group(self._real_mesh_dim(dim))

    def _flatten(self, mesh_dim_name: str) -> "DeviceMesh":
        flatten_mesh = _FlattenDeviceMesh(self)
        if self.parent is None:
            self.flatten_meshes[mesh_dim_name] = flatten_mesh
        else:
            self.parent.flatten_meshes[mesh_dim_name] = flatten_mesh
        return flatten_mesh

    def size(self, mesh_dim: Optional[int] = None) -> int:
        if mesh_dim is None:
            if self.mesh is None:
                return self.replicate_pg.size()
            else:
                return self.mesh.size() * self.replicate_pg.size()
        elif mesh_dim == self.replicate_dim:
            return self.replicate_pg.size()
        else:
            assert self.mesh is not None
            return self.mesh.size(self._real_mesh_dim(mesh_dim))

    @property
    def ndim(self) -> int:
        assert self.mesh is not None
        return self.mesh.ndim + 1

    @property
    def shape(self) -> Tuple[int, ...]:
        assert self.mesh is not None
        ret: List[int] = list(self.mesh.shape)
        ret.insert(self.replicate_dim, self.replicate_pg.size())
        return tuple(ret)

    def get_rank(self) -> int:
        assert self.mesh is not None
        return self.mesh.get_rank()

    def get_local_rank(self, mesh_dim: Optional[Union[int, str]] = None) -> int:
        if isinstance(mesh_dim, str):
            dim = self.mesh_dim_names.index(mesh_dim)
        else:
            dim = 0 if mesh_dim is None else int(mesh_dim)

        if mesh_dim is None:
            if self.mesh is None:
                return get_rank(self.replicate_pg)

            assert self.replicate_dim == 0, "replicate_dim must be the first one"
            other_dim_size = self.mesh.size()
            other_dim_rank = self.mesh.get_local_rank()
            replicate_pg_rank = get_rank(self.replicate_pg)
            return other_dim_size * replicate_pg_rank + other_dim_rank
        elif dim == self.replicate_dim:
            return get_rank(self.replicate_pg)
        else:
            assert self.mesh is not None
            return self.mesh.get_local_rank(self._real_mesh_dim(dim))

    def get_coordinate(self) -> Optional[List[int]]:
        """
        Return the coordinate of this rank across all dimensions of the mesh,
        or None if this rank is not part of the mesh.
        """
        assert self.mesh is not None
        return self.mesh._coordinate_on_dim if self.mesh._coordinate_on_dim else None

    def get_all_groups(self) -> List[BaseProcessGroup]:
        raise NotImplementedError


class _FlattenDeviceMesh(DeviceMesh):
    def __init__(self, managed_mesh: ManagedDeviceMesh) -> None:
        self.managed_mesh = managed_mesh

    def __getitem__(self, mesh_dim_names: Union[str, Tuple[str, ...]]) -> DeviceMesh:
        raise NotImplementedError

    def get_group(self, mesh_dim: Optional[Union[int, str]] = None) -> BaseProcessGroup:
        raise NotImplementedError

    def _flatten(self, mesh_dim_name: Optional[str]) -> "DeviceMesh":
        raise NotImplementedError

    def size(self, mesh_dim: Optional[int] = None) -> int:
        assert mesh_dim is None
        return self.managed_mesh.size()

    @property
    def ndim(self) -> int:
        raise NotImplementedError

    @property
    def shape(self) -> Tuple[int, ...]:
        raise NotImplementedError

    def get_rank(self) -> int:
        raise NotImplementedError

    def get_local_rank(self, mesh_dim: Optional[Union[int, str]] = None) -> int:
        assert mesh_dim is None
        return self.managed_mesh.get_local_rank()

    def get_all_groups(self) -> List[BaseProcessGroup]:
        raise NotImplementedError


def ft_init_device_mesh(
    *,
    device_type: str,
    mesh_shape: Tuple[int, ...],
    mesh_dim_names: Tuple[str, ...],
    replicate_dim: int,
    manager: "Manager",
) -> "ManagedDeviceMesh":
    # Build the inner DeviceMesh without the replicate dim; ManagedDeviceMesh
    # adds that dim back, backed by the ManagedProcessGroup.
    _mesh_shape = list(mesh_shape)
    _mesh_shape.pop(replicate_dim)
    _mesh_dim_names = list(mesh_dim_names)
    _mesh_dim_names.pop(replicate_dim)
    mesh = init_device_mesh(
        device_type,
        mesh_shape=tuple(_mesh_shape),
        mesh_dim_names=tuple(_mesh_dim_names),
    )

    if device_type == "cpu":
        pg = ProcessGroupGloo()
    elif device_type == "cuda":
        pg = ProcessGroupNCCL()
    else:
        raise ValueError(f"Unsupported device_type: {device_type}")

    manager._pg = pg
    replicate_pg = ManagedProcessGroup(manager)
    # We have to use MultiProcessTestCase in tests; otherwise c10d will
    # complain that the same backend has been registered.
    replicate_pg.register(mesh_dim_names[replicate_dim])

    return ManagedDeviceMesh(
        mesh=mesh,
        mesh_dim_names=mesh_dim_names,
        replicate_pg=replicate_pg,
        replicate_dim=replicate_dim,
        parent=None,
    )
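
As a rough illustration of the wrapper's behavior (assuming a `device_mesh` built as in fsdp_test.py above; the names come from this diff, the assertions are ours):

    # get_group on the replicate dim returns the fault-tolerant group; any
    # other dim is forwarded to the inner mesh, shifted by _real_mesh_dim.
    replicate_pg = device_mesh.get_group("dp_replicate")  # ManagedProcessGroup
    shard_pg = device_mesh.get_group("dp_shard")          # inner mesh's group

    # shape splices the managed group's size back in at replicate_dim, so it
    # tracks manager.num_participants() rather than a fixed world size.
    assert device_mesh.shape == (replicate_pg.size(), 2)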