diff --git a/colossalai/moe/_operation.py b/colossalai/moe/_operation.py
index abc221fea2ad..c71e6c1f40c7 100644
--- a/colossalai/moe/_operation.py
+++ b/colossalai/moe/_operation.py
@@ -150,7 +150,8 @@ class HierarchicalAllToAll(torch.autograd.Function):
     def forward(
         ctx: Any,
         inputs: Tensor,
-        groups: Tuple[ProcessGroup],
+        groups: Tuple[ProcessGroup, ProcessGroup],
+        src_rank: int
     ) -> Tensor:
         """
         Returns:
@@ -159,12 +160,12 @@ def forward(
         # TODO: we can reduce comm volume by removing empty capacity
         if ctx is not None:
             ctx.comm_grps = groups
+            ctx.src_rank = src_rank
         intra_node_group, inter_node_group = groups
         local_world_size = dist.get_world_size(intra_node_group)
         num_group = dist.get_world_size(inter_node_group) if inter_node_group is not None else 1
         world_size = local_world_size * num_group
-        src_rank = dist.get_process_group_ranks(intra_node_group)[0]

         outputs = torch.empty_like(inputs)

         if dist.get_rank() == src_rank:
@@ -196,9 +197,10 @@ def forward(
         return outputs

     @staticmethod
-    def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None]:
+    def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None, None]:
         return (
-            HierarchicalAllToAll.forward(None, grad_outputs[0], ctx.comm_grps),
+            HierarchicalAllToAll.forward(None, grad_outputs[0], ctx.comm_grps, ctx.src_rank),
+            None,
             None,
         )

diff --git a/colossalai/moe/layers.py b/colossalai/moe/layers.py
index 2714d6316151..b768fb94a585 100644
--- a/colossalai/moe/layers.py
+++ b/colossalai/moe/layers.py
@@ -13,7 +13,7 @@
 from colossalai.moe.manager import MOE_MANAGER
 from colossalai.moe.routers import MoeRouter, get_router_cls
 from colossalai.moe.utils import create_ep_hierarchical_group, get_noise_generator
-from colossalai.tensor.moe_tensor.api import get_dp_group, get_ep_group, get_ep_size
+from colossalai.tensor.moe_tensor.api import get_dp_group, get_ep_group, get_ep_group_ranks, get_ep_size


 class SparseMLP(nn.Module):
@@ -105,8 +105,11 @@ def __init__(
         if self.expert_parallel is not None:
             self.ep_group = get_ep_group(self.experts)
             self.ep_size = get_ep_size(self.experts)
-            self.ep_hierarchical_group = create_ep_hierarchical_group(
-                self.ep_group) if enable_hierarchical_comm else None
+            self.ep_hierarchical_group = None
+            if enable_hierarchical_comm:
+                self.ep_intra_src_rank, *self.ep_hierarchical_group = create_ep_hierarchical_group(
+                    get_ep_group_ranks(self.experts)
+                )
             self.dp_group = get_dp_group(self.experts)
         else:
             self.ep_group = None
@@ -225,10 +228,10 @@ def _ep_process(
         """
         if not overlap or dist.get_world_size(self.ep_group) == 1:
             if self.ep_hierarchical_group is not None:
-                expert_input = HierarchicalAllToAll.apply(dispatch_data, self.ep_hierarchical_group)
+                expert_input = HierarchicalAllToAll.apply(dispatch_data, self.ep_hierarchical_group, self.ep_intra_src_rank)
                 expert_input = expert_input.reshape(self.ep_size, self.num_local_experts, -1, self.hidden_size)
                 expert_output = self.experts(expert_input)
-                expert_output = HierarchicalAllToAll.apply(expert_output, self.ep_hierarchical_group)
+                expert_output = HierarchicalAllToAll.apply(expert_output, self.ep_hierarchical_group, self.ep_intra_src_rank)
                 return expert_output
             else:
                 expert_input = AllToAll.apply(dispatch_data, self.ep_group, False)[0]
diff --git a/colossalai/moe/utils.py b/colossalai/moe/utils.py
index 5180f6ea6274..5a17a6e0d769 100644
--- a/colossalai/moe/utils.py
+++ b/colossalai/moe/utils.py
@@ -179,15 +179,15 @@ def set_moe_args(config: Any, args: dict):


 def create_ep_hierarchical_group(
-    ep_group: dist.ProcessGroup,
+    ep_group_ranks: List[int],
     nproc_per_node: Optional[int] = None,
-) -> Tuple[Optional[dist.ProcessGroup],
-           Optional[dist.ProcessGroup]]:
+) -> Tuple[int, dist.ProcessGroup, Optional[dist.ProcessGroup]]:
     """
     e.g., If ep_group = [1, 2, 5, 6], and nproc_per_node = 4
     Then, ep_intra_group = [1, 2] & [5, 6], ep_inter_group = [1, 5] & None
     """
     assert dist.is_initialized(), "Please initialize torch.distributed first."
+    rank = dist.get_rank()
     if nproc_per_node is None:
         nproc_per_node = os.environ.get("LOCAL_WORLD_SIZE")
         assert nproc_per_node is not None, "Please use torchrun to launch the job, or specify nproc_per_node manually."
@@ -197,24 +197,23 @@ def create_ep_hierarchical_group(
         "nproc_per_node should be a divisor of world_size."
     num_node = dist.get_world_size() // nproc_per_node

-    rank = dist.get_rank()
-    ep_ranks = dist.get_process_group_ranks(ep_group)
-
+    intra_src_rank = None
     ep_intra_node_group = None
     for i in range(num_node):
         ep_intra_ranks = [
             i * nproc_per_node + j
             for j in range(nproc_per_node)
-            if j in ep_ranks
+            if j in ep_group_ranks
         ]
         group = dist.new_group(ep_intra_ranks)
         if rank in ep_intra_ranks:
             assert ep_intra_node_group is None
             ep_intra_node_group = group
+            intra_src_rank = ep_intra_ranks[0]

     ep_inter_node_group = None
     ep_inter_ranks = [
-        ep_ranks[0] + i * nproc_per_node
+        ep_group_ranks[0] + i * nproc_per_node
         for i in range(num_node)
     ]
     if len(ep_inter_ranks) > 1:
@@ -222,4 +221,4 @@
         if rank in ep_inter_ranks:
             ep_inter_node_group = group

-    return ep_intra_node_group, ep_inter_node_group
+    return intra_src_rank, ep_intra_node_group, ep_inter_node_group
diff --git a/colossalai/tensor/moe_tensor/api.py b/colossalai/tensor/moe_tensor/api.py
index c452f0d638fe..1e4486101dd3 100644
--- a/colossalai/tensor/moe_tensor/api.py
+++ b/colossalai/tensor/moe_tensor/api.py
@@ -1,3 +1,5 @@
+from typing import List
+
 import torch
 import torch.distributed as dist
 from torch.distributed import ProcessGroup
@@ -124,7 +126,7 @@ def get_dp_rank(tensor: torch.Tensor) -> int:
     return dist.get_rank(get_dp_group(tensor))


-def get_ep_group_ranks(tensor: torch.Tensor) -> int:
+def get_ep_group_ranks(tensor: torch.Tensor) -> List[int]:
     """
     Get the expert parallel group ranks of the given tensor.

@@ -137,7 +139,7 @@ def get_ep_group_ranks(tensor: torch.Tensor) -> int:
     return tensor.moe_info.ep_group_ranks


-def get_dp_group_ranks(tensor: torch.Tensor) -> int:
+def get_dp_group_ranks(tensor: torch.Tensor) -> List[int]:
     """
     Get the data parallel group ranks of the given tensor.

diff --git a/tests/test_moe/test_moe_ep_tp.py b/tests/test_moe/test_moe_ep_tp.py
index d5557a41f139..f87d4c792155 100644
--- a/tests/test_moe/test_moe_ep_tp.py
+++ b/tests/test_moe/test_moe_ep_tp.py
@@ -1,5 +1,6 @@
 import os
 import warnings
+from typing import Dict

 import pytest
 import torch
@@ -123,7 +124,7 @@ def sync_local_from_ep(local_model: SparseMLP, ep_model: SparseMLP, assert_grad_
         local_param.data.copy_(all_param.data)


-def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size: int, dim: int, seed: int):
+def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size: int, dim: int, config: Dict):
     assert batch_size % world_size == 0

     colossalai.launch(config=dict(), rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
@@ -133,8 +134,9 @@ def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size
     local_model = SparseMLP(num_experts=num_experts, hidden_size=dim, intermediate_size=dim * 2)
     MOE_MANAGER.__init__()
     MOE_MANAGER.setup(parallel="EP")
-    os.environ["LOCAL_WORLD_SIZE"] = str(world_size)
-    enable_hierarchical_comm = torch.__version__ >= "1.13.1"
+    enable_hierarchical_comm = config.get("enable_hierarchical_comm", False)
+    if enable_hierarchical_comm:
+        os.environ["LOCAL_WORLD_SIZE"] = str(world_size)
     ep_model = SparseMLP(
         num_experts=num_experts,
         hidden_size=dim,
@@ -161,7 +163,6 @@ def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size
     tp_grad_handler = MoeGradientHandler(tp_model)

     rank = dist.get_rank()
-    torch.cuda.manual_seed(seed)
     input_data = torch.randn(batch_size, dim, device=get_current_device())
     micro_batch_size = batch_size // world_size
     index = rank * micro_batch_size
@@ -218,11 +219,14 @@ def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size
 @pytest.mark.parametrize("num_experts", [4, 64])
 @pytest.mark.parametrize("batch_size", [16])
 @pytest.mark.parametrize("dim", [64])
-@pytest.mark.parametrize("seed", [42, 127])
+@pytest.mark.parametrize("config", [
+    {"enable_hierarchical_comm": False},
+    {"enable_hierarchical_comm": True},
+])
 @rerun_if_address_is_in_use()
-def test_moe_ep_tp(num_experts: int, batch_size: int, dim: int, seed: int):
-    spawn(run_test, 2, num_experts=num_experts, batch_size=batch_size, dim=dim, seed=seed)
+def test_moe_ep_tp(num_experts: int, batch_size: int, dim: int, config: Dict):
+    spawn(run_test, 2, num_experts=num_experts, batch_size=batch_size, dim=dim, config=config)


 if __name__ == '__main__':
-    test_moe_ep_tp(num_experts=8, batch_size=32, dim=32, seed=42)
+    test_moe_ep_tp(num_experts=8, batch_size=32, dim=32)
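
For reference, a minimal usage sketch of the revised API after this patch (not part of the diff itself). It assumes torch.distributed is already initialized via torchrun so that LOCAL_WORLD_SIZE is set; `experts` and `dispatch_data` are hypothetical placeholders standing in for the MoE expert parameter tensor and the dispatched token tensor used inside SparseMLP._ep_process.

    from colossalai.moe._operation import HierarchicalAllToAll
    from colossalai.moe.utils import create_ep_hierarchical_group
    from colossalai.tensor.moe_tensor.api import get_ep_group_ranks

    # create_ep_hierarchical_group now takes the EP group ranks (List[int])
    # instead of a ProcessGroup, and additionally returns the intra-node
    # source rank alongside the intra- and inter-node groups.
    intra_src_rank, intra_group, inter_group = create_ep_hierarchical_group(
        get_ep_group_ranks(experts)  # `experts` is a placeholder MoE tensor
    )

    # The source rank is now passed explicitly, matching the new
    # forward(ctx, inputs, groups, src_rank) signature; backward
    # correspondingly returns an extra None for the added argument.
    expert_input = HierarchicalAllToAll.apply(
        dispatch_data,  # placeholder: dispatched tokens
        (intra_group, inter_group),
        intra_src_rank,
    )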