Skip to content

Commit

Permalink
Introduce ManagedDeviceMesh to integrate DeviceMesh with TorchFT
Browse files Browse the repository at this point in the history
Summary:
ManagedDeviceMesh allow users to manipulate DeviceMesh with TorchFT ManagedProcessGroup.

ghstack-source-id: 888be370d2f8e81fbe0a9a29a9a99a4e6404cab8
Pull Request resolved: #56
  • Loading branch information
fegin committed Jan 7, 2025
1 parent f31d3b1 commit 1f5c9cd
Show file tree
Hide file tree
Showing 3 changed files with 335 additions and 13 deletions.
70 changes: 70 additions & 0 deletions torchft/fsdp_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import os
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Dict, Tuple
from unittest import skipUnless, TestCase
from unittest.mock import Mock

import torch
import torch.distributed as dist
from torch import nn
from torch._C._distributed_c10d import (
_resolve_process_group,
AllgatherOptions,
AllreduceOptions,
BroadcastOptions,
ReduceOp,
)
from torch.distributed import (
_functional_collectives,
get_world_size,
ReduceOp,
TCPStore,
Work,
)
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed.device_mesh import init_device_mesh
from torch.testing._internal.common_distributed import MultiProcessTestCase

from torchft.manager import Manager
from torchft.process_group import ft_init_device_mesh, ManagedProcessGroup


class FSDPTest(MultiProcessTestCase):
@property
def world_size(self):
return 4

def setUp(self):
super().setUp()
os.environ["TORCH_NCCL_DESYNC_DEBUG"] = "0"
self._spawn_processes()

def test_fsdp(self) -> None:
group_size = self.world_size // 2
group = self.rank // group_size
group_rank = self.rank % group_size

os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = str(12346 + group)
os.environ["RANK"] = str(group_rank)
os.environ["WORLD_SIZE"] = str(group_size)

manager = Mock(spec=Manager)
device_mesh = ft_init_device_mesh(
device_type="cuda",
mesh_shape=(2, 2),
mesh_dim_names=("dp_replicate", "dp_shard"),
replicate_dim=0,
manager=manager,
)
manager.num_participants.return_value = 1
model = nn.Linear(128, 128).cuda()
batch = torch.randn(4, 128).cuda()
shard_model = fully_shard(model, mesh=device_mesh)
shard_model(batch).mean().backward()
232 changes: 219 additions & 13 deletions torchft/process_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import threading
from abc import ABC
from datetime import timedelta
from typing import TYPE_CHECKING, Dict, List, Optional, Type
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type, Union

import torch
import torch.distributed as dist
Expand All @@ -38,6 +38,7 @@
Store,
TCPStore,
get_rank,
init_device_mesh,
)
from torch.distributed.distributed_c10d import Work, _world
from torch.futures import Future
Expand Down Expand Up @@ -130,17 +131,7 @@ def size(self) -> int:
def getBackendName(self) -> str:
raise NotImplementedError("not implemented")

def register(self, name: str) -> "ProcessGroup":
"""
Registers the process group with the global registry. This enables usage
with things like functional_collectives which are compilable.
This should only be called once.
Args:
name: name must be a unique name for this process group
"""

def _register(self, name: str) -> str:
group_name = f"{self.getBackendName()}:{name}"

# This is needed for DeviceMesh and functional collectives to work.
Expand All @@ -158,6 +149,21 @@ def create_pg(
devices = ["cpu"]
dist.Backend.register_backend(group_name, create_pg, devices=devices)

return group_name

def register(self, name: str) -> "ProcessGroup":
"""
Registers the process group with the global registry. This enables usage
with things like functional_collectives which are compilable.
This should only be called once.
Args:
name: name must be a unique name for this process group
"""

group_name = self._register(name)

return dist.new_group(
ranks=[dist.get_rank()],
backend=group_name,
Expand Down Expand Up @@ -496,6 +502,9 @@ def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
def size(self) -> int:
return self._manager.num_participants()

def getBackendName(self) -> str:
return self._manager._pg.getBackendName()


class _BabyWork(Work):
def __init__(
Expand Down Expand Up @@ -689,7 +698,6 @@ def _future_handler(self, future_queue: mp.Queue) -> None:
logger.exception(f"got unexpected error in future handler: {e}")

def _get_future(self, op_id: int) -> Future[object]:

with self._futures_lock:
fut = Future() # pyre-fixme[29]: is not a function
self._futures[op_id] = fut
Expand Down Expand Up @@ -797,3 +805,201 @@ def extend_device_mesh(
mesh=mesh.mesh.unsqueeze(dim),
mesh_dim_names=tuple(mesh_dim_names),
)


class _ManagedDeviceMesh(DeviceMesh):
def __init__(
self,
mesh: Optional[DeviceMesh],
mesh_dim_names: Tuple[str],
replicate_pg: ManagedProcessGroup,
replicate_dim: int,
parent: Optional["_ManagedDeviceMesh"],
):
if mesh is None and parent is not None:
raise ValueError(
"_ManagedDeviceMesh doesn't support both mesh and parent are None."
)
self.mesh = mesh
self.mesh_dim_names = mesh_dim_names
self.replicate_pg = replicate_pg
self.replicate_dim = replicate_dim
self.replicate_dim_name = mesh_dim_names[replicate_dim]
self.parent = parent
self.flatten_meshes = {}
self.device_type = mesh.device_type if mesh is not None else parent.device_type
self._flatten_mesh_list = tuple()
self._thread_id = None

def __getitem__(self, mesh_dim_names: Union[str, Tuple[str, ...]]) -> DeviceMesh:
if isinstance(mesh_dim_names, str):
if mesh_dim_names == self.replicate_dim_name:
return _ManagedDeviceMesh(
mesh=None,
mesh_dim_names=(mesh_dim_names,),
replicate_pg=self.replicate_pg,
replicate_dim=0,
parent=self,
)
elif mesh_dim_names in self.flatten_meshes:
return self.flatten_meshes[mesh_dim_names]
else:
return self.mesh[mesh_dim_names]
else:
assert isinstance(mesh_dim_names, tuple)
if self.replicate_dim_name in mesh_dim_names:
return self.mesh[mesh_dim_names]
else:
return _ManagedDeviceMesh(
self.mesh[mesh_dim_names],
mesh_dim_names,
self.replicate_pg,
mesh_dim_name.index(self.replicate_dim_name),
parent=self,
)

def _real_mesh_dim(self, mesh_dim: int) -> int:
return mesh_dim - 1 if mesh_dim > self.replicate_dim else mesh_dim

def get_group(self, mesh_dim: Optional[str] = None) -> BaseProcessGroup:
if mesh_dim is None:
assert self.mesh is None
return self.replicate_pg
elif mesh_dim == self.replicate_dim_name:
return self.replicate_pg
else:
return self.mesh.get_group(self._real_mesh_dim(mesh_dim))

def _flatten(self, mesh_dim_name: str) -> "DeviceMesh":
flatten_mesh = _FlattenDeviceMesh(self)
if self.parent is None:
self.flatten_meshes[mesh_dim_name] = flatten_mesh
else:
self.parent.flatten_meshes[mesh_dim_name] = flatten_mesh
return flatten_mesh

def size(self, mesh_dim: Optional[int] = None) -> int:
if mesh_dim is None:
if self.mesh is None:
return self.replicate_pg.size()
else:
return self.mesh.size() * self.replicate_pg.size()
elif mesh_dim == self.replicate_dim:
return self.replicate_pg.size()
else:
return self.mesh.size(self._real_mesh_dim(mesh_dim))

@property
def ndim(self) -> int:
return self.mesh.ndim + 1

@property
def shape(self) -> Tuple[int, ...]:
ret = list(self.mesh.shape)
ret.insert(self.replicate_dim, self.replicate_pg.size())

def get_rank(self) -> int:
return self.mesh.get_rank()

def get_local_rank(self, mesh_dim: Optional[Union[int, str]] = None) -> int:
if mesh_dim is None:
if self.mesh is None:
return get_rank(self.replicate_pg)

assert self.replicate_dim == 0, "replicate_dim must be the first one"
other_dim_size = self.mesh.size()
other_dim_rank = self.mesh.get_local_rank()
replicate_pg_rank = get_rank(self.replicate_pg)
return other_dim_size * replicate_pg_rank + other_dim_rank
elif mesh_dim in (self.replicate_dim, self.replicate_dim_name):
return get_rank(self.replicate_pg)
else:
return self.mesh.get_local_rank(self._real_mesh_dim(mesh_dim))

def get_coordinate(self) -> Optional[List[int]]:
"""
Return the relative indices of this rank relative to all
dimensions of the mesh. If this rank is not part of the mesh, return None.
"""
return self.mesh._coordinate_on_dim if self.mesh._coordinate_on_dim else None

def get_all_groups(self) -> List[ProcessGroup]:
raise NotImplementedError


class _FlattenDeviceMesh(DeviceMesh):
def __init__(self, managed_mesh: _ManagedDeviceMesh):
self.managed_mesh = managed_mesh

def __getitem__(self, mesh_dim_names: Union[str, Tuple[str, ...]]) -> DeviceMesh:
raise NotImplementedError

def get_group(self, mesh_dim: Optional[str] = None) -> BaseProcessGroup:
raise NotImplementedError

def _flatten(self, mesh_dim_name: str) -> "DeviceMesh":
raise NotImplementedError

def size(self, mesh_dim: Optional[int] = None) -> int:
assert mesh_dim is None
return self.managed_mesh.size()

@property
def ndim(self) -> int:
raise NotImplementedError

@property
def shape(self) -> Tuple[int, ...]:
raise NotImplementedError

def get_rank(self) -> int:
raise NotImplementedError

def get_local_rank(self, mesh_dim: Optional[Union[int, str]] = None) -> int:
assert mesh_dim is None
return self.managed_mesh.get_local_rank()

def get_all_groups(self) -> List[ProcessGroup]:
raise NotImplementedError


def ft_init_device_mesh(
*,
device_type: str,
mesh_shape: Tuple[int, ...],
mesh_dim_names: Tuple[str, ...],
replicate_dim: int,
manager: "Manager",
):
# We need to mislead DeviceMesh into thinking that replicate_dim has only
# 1 rank.
_mesh_shape = list(mesh_shape)
_mesh_shape.pop(replicate_dim)
_mesh_dim_names = list(mesh_dim_names)
_mesh_dim_names.pop(replicate_dim)
mesh = init_device_mesh(
device_type,
mesh_shape=tuple(_mesh_shape),
mesh_dim_names=tuple(_mesh_dim_names),
)

if device_type == "cpu":
pg = ProcessGroupGloo()
elif device_type == "cuda":
pg = ProcessGroupNCCL()
else:
raise ValueError()

manager._pg = pg
replicate_pg = ManagedProcessGroup(manager)
# We have to use MultiProcessTestCase, otherwise c10d will complain
# the same backend has been registered.
replicate_pg.register(mesh_dim_names[replicate_dim])

return _ManagedDeviceMesh(
mesh=mesh,
mesh_dim_names=mesh_dim_names,
replicate_pg=replicate_pg,
replicate_dim=replicate_dim,
parent=None,
)
Loading

0 comments on commit 1f5c9cd

Please sign in to comment.