[moe] support low level zero optim #4429

Merged: 5 commits on Aug 14, 2023
3 changes: 1 addition & 2 deletions colossalai/engine/gradient_handler/__init__.py
@@ -1,11 +1,10 @@
 from ._base_gradient_handler import BaseGradientHandler
 from ._data_parallel_gradient_handler import DataParallelGradientHandler
-from ._moe_gradient_handler import MoeGradientHandler
 from ._pipeline_parallel_gradient_handler import PipelineSharedModuleGradientHandler
 from ._sequence_parallel_gradient_handler import SequenceParallelGradientHandler
 from ._zero_gradient_handler import ZeROGradientHandler

 __all__ = [
     'BaseGradientHandler', 'DataParallelGradientHandler', 'ZeROGradientHandler', 'PipelineSharedModuleGradientHandler',
-    'MoeGradientHandler', 'SequenceParallelGradientHandler'
+    'SequenceParallelGradientHandler'
 ]
46 changes: 0 additions & 46 deletions colossalai/engine/gradient_handler/_moe_gradient_handler.py

This file was deleted.

3 changes: 2 additions & 1 deletion colossalai/nn/layer/moe/experts.py
@@ -8,6 +8,7 @@

 from colossalai.context import ParallelMode, seed
 from colossalai.context.moe_context import MOE_CONTEXT
+from colossalai.tensor.moe_tensor.api import set_moe_param_info
 from colossalai.utils import get_current_device
 from colossalai.zero.legacy.init_ctx import no_shard_zero_decrator

@@ -50,7 +51,7 @@ def __init__(self, expert_cls: Type[nn.Module], num_experts: int, **expert_args)
         # Attach parallel information for all parameters in Experts
         for exp in self.experts:
             for param in exp.parameters():
-                param.__setattr__('moe_info', self.dist_info)
+                set_moe_param_info(param, self.dist_info)

     def forward(self, inputs: torch.Tensor):
         # Split inputs for each expert
26 changes: 26 additions & 0 deletions colossalai/tensor/moe_tensor/api.py
@@ -0,0 +1,26 @@
import torch


def is_moe_param(tensor: torch.Tensor) -> bool:
    """
    Check whether the given tensor is a moe param.

    Args:
        tensor (torch.Tensor): The tensor to be checked.

    Returns:
        bool: Whether the given tensor is a moe param.
    """
    return hasattr(tensor, "moe_info")


def set_moe_param_info(tensor: torch.Tensor, moe_info: dict) -> None:
    """
    Set moe info for the given tensor.

    Args:
        tensor (torch.Tensor): The tensor to be set.
        moe_info (dict): The moe info to be set.

    """
    tensor.__setattr__('moe_info', moe_info)
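
For reference, a minimal usage sketch of these two helpers (not part of the diff; the dist_info payload below is a hypothetical stand-in for the parallel info a real MoE layer would attach):

import torch.nn as nn

from colossalai.tensor.moe_tensor.api import is_moe_param, set_moe_param_info

expert = nn.Linear(16, 16)
dist_info = {'ep_size': 2, 'dp_size': 1}    # hypothetical moe_info payload

# Tag every parameter of the expert with its parallel information.
for param in expert.parameters():
    set_moe_param_info(param, dist_info)

# Downstream code (such as the ZeRO optimizer changed below) can then tell
# MoE params apart from ordinary data-parallel params with a cheap check.
moe_params = [p for p in expert.parameters() if is_moe_param(p)]
assert len(moe_params) == 2    # weight and bias are both tagged
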
28 changes: 25 additions & 3 deletions colossalai/zero/low_level/low_level_optim.py
@@ -16,6 +16,7 @@
 )
 from colossalai.interface import OptimizerWrapper
 from colossalai.logging import get_dist_logger
+from colossalai.tensor.moe_tensor.api import is_moe_param
 # from colossalai.tensor import ColoParameter, ProcessGroup
 from colossalai.utils.cuda import get_current_device

@@ -131,16 +132,23 @@ def __init__(
         self._grad_store = GradientStore(self.dp_pg, partition_grad=partition_grad)
         self._bucket_store = BucketStore(self.dp_pg)

+        # moe param should not be stored in working_groups
+        # because they have different parallel strategy
+        # so we need to store them separately in param_groups
+        # instead of working_groups
+        moe_params = list()
+
         # iterate over the param group in the optimizer
         # partition these param groups for data parallel training
         # and add buffers to parameter store for future access
         for group_id, param_group in enumerate(self.optim.param_groups):
             group_params = list()
             for param in param_group['params']:
-                # skip moe param
-                if hasattr(param, "moe_info"):
-                    continue
                 if param.requires_grad:
+                    # skip moe param
+                    if is_moe_param(param):
+                        moe_params.append(param)
+                        continue
                     group_params.append(param)

             # add the working params to working_param_groups for bookkeeping
@@ -155,6 +163,15 @@
             # managed by this data parallel rank
             param_group['params'] = master_param_current_rank

+        # if there are moe params, store in addtional group in optim
+        if len(moe_params) > 0:
+            param_group = dict()
+            for key, value in self.optim.param_groups[0].items():
+                if key != 'params':
+                    param_group[key] = value
+            param_group['params'] = moe_params
+            self.optim.param_groups.append(param_group)
+
         # intialize communication stream for
         # communication-compuation overlapping
         if self._overlap_communication:
@@ -418,6 +435,11 @@ def step(self, closure=None):
         # update the parameters
         self.optim.step()

+        # release the moe grad
+        if len(self.param_groups) > len(self._working_param_groups):
+            for param in self.param_groups[-1]['params']:
+                param.grad = None
+
         # release the grad
         grad_partition_groups = []
         for group_id in range(self.num_param_groups):
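
Taken together, these changes keep MoE parameters out of the ZeRO-sharded working groups: they are collected into an extra param group that the wrapped optimizer updates directly, and their gradients are released right after step(). A minimal, self-contained sketch of that grouping pattern with a plain PyTorch optimizer (the names and the moe_flags list are illustrative, not the PR's code):

import torch
from torch.optim import Adam

# Three toy parameters; pretend the middle one belongs to an MoE expert.
params = [torch.nn.Parameter(torch.randn(4)) for _ in range(3)]
moe_flags = [False, True, False]

optim = Adam(params, lr=1e-3)

dense_params, moe_params = [], []
for p, is_moe in zip(optim.param_groups[0]['params'], moe_flags):
    (moe_params if is_moe else dense_params).append(p)

# Keep only dense params in the original (to-be-sharded) group and append
# the MoE params as an extra group that copies the original hyperparameters.
optim.param_groups[0]['params'] = dense_params
if moe_params:
    extra = {k: v for k, v in optim.param_groups[0].items() if k != 'params'}
    extra['params'] = moe_params
    optim.param_groups.append(extra)

# After optim.step(), the extra group's gradients would be released by hand,
# mirroring the `param.grad = None` loop added to step() above.
for p in optim.param_groups[-1]['params']:
    p.grad = None
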
45 changes: 45 additions & 0 deletions tests/test_moe/moe_utils.py
@@ -1,8 +1,15 @@
import torch.nn as nn

from colossalai.context import MOE_CONTEXT
from colossalai.context.moe_context import MOE_CONTEXT
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.engine.gradient_handler._base_gradient_handler import BaseGradientHandler
from colossalai.engine.gradient_handler.utils import bucket_allreduce
from colossalai.nn import CheckpointModule
from colossalai.nn.layer import MoeModule
from colossalai.registry import GRADIENT_HANDLER
from colossalai.utils.moe import get_moe_epsize_param_dict


class MoeModel(nn.Module):
@@ -39,3 +46,41 @@ def forward(self, x):

        MOE_CONTEXT.add_loss(y)
        return x


@GRADIENT_HANDLER.register_module
class MoeGradientHandler(BaseGradientHandler):
    """A helper class to handle all-reduce operations in a data parallel group and
    moe model parallel. A all-reduce collective communication will be operated in
    :func:`handle_gradient` among a data parallel group.
    For better performance, it bucketizes the gradients of all parameters that are
    the same type to improve the efficiency of communication.

    Args:
        model (Module): Model where the gradients accumulate.
        optimizer (Optimizer): Optimizer for updating the parameters.
    """

    def __init__(self, model, optimizer=None):
        super().__init__(model, optimizer)

    def handle_gradient(self):
        """A method running an all-reduce operation in a data parallel group.
        Then running an all-reduce operation for all parameters in experts
        across moe model parallel group
        """
        global_data = gpc.data_parallel_size

        if global_data > 1:
            epsize_param_dict = get_moe_epsize_param_dict(self._model)

            # epsize is 1, indicating the params are replicated among processes in data parallelism
            # use the ParallelMode.DATA to get data parallel group
            # reduce gradients for all parameters in data parallelism
            if 1 in epsize_param_dict:
                bucket_allreduce(param_list=epsize_param_dict[1], group=gpc.get_group(ParallelMode.DATA))

            for ep_size in epsize_param_dict:
                if ep_size != 1 and ep_size != MOE_CONTEXT.world_size:
                    bucket_allreduce(param_list=epsize_param_dict[ep_size],
                                     group=MOE_CONTEXT.parallel_info_dict[ep_size].dp_group)
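
The handler delegates the actual communication to bucket_allreduce. As a rough illustration of what a bucketed gradient all-reduce does (a simplified sketch that assumes all gradients share one dtype and averages over the group; this is not ColossalAI's implementation):

import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors


def naive_bucket_allreduce(param_list, group=None):
    # Collect existing gradients (assumed to share one dtype and device).
    grads = [p.grad.data for p in param_list if p.grad is not None]
    if not grads:
        return
    # One flat buffer means one collective call instead of one per parameter.
    flat = _flatten_dense_tensors(grads)
    dist.all_reduce(flat, group=group)
    flat.div_(dist.get_world_size(group=group))
    # Copy the synchronized values back into each parameter's gradient.
    for grad, synced in zip(grads, _unflatten_dense_tensors(flat, grads)):
        grad.copy_(synced)

In the handler above, such a call runs once for purely data-parallel parameters (over the ParallelMode.DATA group) and once for each expert-parallel size (other than 1 and the full MoE world size) over its data-parallel sub-group.
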
2 changes: 1 addition & 1 deletion tests/test_moe/test_grad_handler.py
@@ -5,11 +5,11 @@

 import colossalai
 from colossalai.context.moe_context import MOE_CONTEXT
-from colossalai.engine.gradient_handler import MoeGradientHandler
 from colossalai.nn.layer.moe import Experts, MoeLayer, Top1Router, UniformNoiseGenerator
 from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use, spawn
 from colossalai.utils import get_current_device
 from colossalai.utils.moe import sync_moe_model_param
+from tests.test_moe.moe_utils import MoeGradientHandler

 BATCH_SIZE = 4
 DIM = 16
3 changes: 1 addition & 2 deletions tests/test_moe/test_moe_zero_fwd_bwd.py
@@ -6,11 +6,10 @@
 from colossalai.booster.plugin import LowLevelZeroPlugin
 from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel
 from colossalai.context import MOE_CONTEXT
-from colossalai.engine.gradient_handler import MoeGradientHandler
 from colossalai.nn import MoeLoss
 from colossalai.testing import rerun_if_address_is_in_use, spawn
 from colossalai.testing.random import seed_all
-from tests.test_moe.moe_utils import MoeModel
+from tests.test_moe.moe_utils import MoeGradientHandler, MoeModel


 def split_ddp_grad(grad, world_size):