diff --git a/colossalai/engine/gradient_handler/__init__.py b/colossalai/engine/gradient_handler/__init__.py index 2dea768bad7e..633e9f885918 100644 --- a/colossalai/engine/gradient_handler/__init__.py +++ b/colossalai/engine/gradient_handler/__init__.py @@ -1,11 +1,10 @@ from ._base_gradient_handler import BaseGradientHandler from ._data_parallel_gradient_handler import DataParallelGradientHandler -from ._moe_gradient_handler import MoeGradientHandler from ._pipeline_parallel_gradient_handler import PipelineSharedModuleGradientHandler from ._sequence_parallel_gradient_handler import SequenceParallelGradientHandler from ._zero_gradient_handler import ZeROGradientHandler __all__ = [ 'BaseGradientHandler', 'DataParallelGradientHandler', 'ZeROGradientHandler', 'PipelineSharedModuleGradientHandler', - 'MoeGradientHandler', 'SequenceParallelGradientHandler' + 'SequenceParallelGradientHandler' ] diff --git a/colossalai/engine/gradient_handler/_moe_gradient_handler.py b/colossalai/engine/gradient_handler/_moe_gradient_handler.py deleted file mode 100644 index b499345d4e18..000000000000 --- a/colossalai/engine/gradient_handler/_moe_gradient_handler.py +++ /dev/null @@ -1,46 +0,0 @@ -from colossalai.context.moe_context import MOE_CONTEXT -from colossalai.core import global_context as gpc -from colossalai.registry import GRADIENT_HANDLER -from colossalai.utils.moe import get_moe_epsize_param_dict - -from ...context.parallel_mode import ParallelMode -from ._base_gradient_handler import BaseGradientHandler -from .utils import bucket_allreduce - - -@GRADIENT_HANDLER.register_module -class MoeGradientHandler(BaseGradientHandler): - """A helper class to handle all-reduce operations in a data parallel group and - moe model parallel. A all-reduce collective communication will be operated in - :func:`handle_gradient` among a data parallel group. - For better performance, it bucketizes the gradients of all parameters that are - the same type to improve the efficiency of communication. - - Args: - model (Module): Model where the gradients accumulate. - optimizer (Optimizer): Optimizer for updating the parameters. - """ - - def __init__(self, model, optimizer=None): - super().__init__(model, optimizer) - - def handle_gradient(self): - """A method running an all-reduce operation in a data parallel group. 
- Then running an all-reduce operation for all parameters in experts - across moe model parallel group - """ - global_data = gpc.data_parallel_size - - if global_data > 1: - epsize_param_dict = get_moe_epsize_param_dict(self._model) - - # epsize is 1, indicating the params are replicated among processes in data parallelism - # use the ParallelMode.DATA to get data parallel group - # reduce gradients for all parameters in data parallelism - if 1 in epsize_param_dict: - bucket_allreduce(param_list=epsize_param_dict[1], group=gpc.get_group(ParallelMode.DATA)) - - for ep_size in epsize_param_dict: - if ep_size != 1 and ep_size != MOE_CONTEXT.world_size: - bucket_allreduce(param_list=epsize_param_dict[ep_size], - group=MOE_CONTEXT.parallel_info_dict[ep_size].dp_group) diff --git a/colossalai/nn/layer/moe/experts.py b/colossalai/nn/layer/moe/experts.py index 56b11f4d9e08..25de4364cb39 100644 --- a/colossalai/nn/layer/moe/experts.py +++ b/colossalai/nn/layer/moe/experts.py @@ -8,6 +8,7 @@ from colossalai.context import ParallelMode, seed from colossalai.context.moe_context import MOE_CONTEXT +from colossalai.tensor.moe_tensor.api import set_moe_param_info from colossalai.utils import get_current_device from colossalai.zero.legacy.init_ctx import no_shard_zero_decrator @@ -50,7 +51,7 @@ def __init__(self, expert_cls: Type[nn.Module], num_experts: int, **expert_args) # Attach parallel information for all parameters in Experts for exp in self.experts: for param in exp.parameters(): - param.__setattr__('moe_info', self.dist_info) + set_moe_param_info(param, self.dist_info) def forward(self, inputs: torch.Tensor): # Split inputs for each expert diff --git a/colossalai/tensor/moe_tensor/api.py b/colossalai/tensor/moe_tensor/api.py new file mode 100644 index 000000000000..11d07ef8c804 --- /dev/null +++ b/colossalai/tensor/moe_tensor/api.py @@ -0,0 +1,26 @@ +import torch + + +def is_moe_param(tensor: torch.Tensor) -> bool: + """ + Check whether the given tensor is a moe param. + + Args: + tensor (torch.Tensor): The tensor to be checked. + + Returns: + bool: Whether the given tensor is a moe param. + """ + return hasattr(tensor, "moe_info") + + +def set_moe_param_info(tensor: torch.Tensor, moe_info: dict) -> None: + """ + Set moe info for the given tensor. + + Args: + tensor (torch.Tensor): The tensor to be set. + moe_info (dict): The moe info to be set. 
+ + """ + tensor.__setattr__('moe_info', moe_info) diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py index 7c00f0c450c1..3516e4df4bba 100644 --- a/colossalai/zero/low_level/low_level_optim.py +++ b/colossalai/zero/low_level/low_level_optim.py @@ -16,6 +16,7 @@ ) from colossalai.interface import OptimizerWrapper from colossalai.logging import get_dist_logger +from colossalai.tensor.moe_tensor.api import is_moe_param # from colossalai.tensor import ColoParameter, ProcessGroup from colossalai.utils.cuda import get_current_device @@ -131,16 +132,23 @@ def __init__( self._grad_store = GradientStore(self.dp_pg, partition_grad=partition_grad) self._bucket_store = BucketStore(self.dp_pg) + # moe param should not be stored in working_groups + # because they have different parallel strategy + # so we need to store them separately in param_groups + # instead of working_groups + moe_params = list() + # iterate over the param group in the optimizer # partition these param groups for data parallel training # and add buffers to parameter store for future access for group_id, param_group in enumerate(self.optim.param_groups): group_params = list() for param in param_group['params']: - # skip moe param - if hasattr(param, "moe_info"): - continue if param.requires_grad: + # skip moe param + if is_moe_param(param): + moe_params.append(param) + continue group_params.append(param) # add the working params to working_param_groups for bookkeeping @@ -155,6 +163,15 @@ def __init__( # managed by this data parallel rank param_group['params'] = master_param_current_rank + # if there are moe params, store in addtional group in optim + if len(moe_params) > 0: + param_group = dict() + for key, value in self.optim.param_groups[0].items(): + if key != 'params': + param_group[key] = value + param_group['params'] = moe_params + self.optim.param_groups.append(param_group) + # intialize communication stream for # communication-compuation overlapping if self._overlap_communication: @@ -418,6 +435,11 @@ def step(self, closure=None): # update the parameters self.optim.step() + # release the moe grad + if len(self.param_groups) > len(self._working_param_groups): + for param in self.param_groups[-1]['params']: + param.grad = None + # release the grad grad_partition_groups = [] for group_id in range(self.num_param_groups): diff --git a/tests/test_moe/moe_utils.py b/tests/test_moe/moe_utils.py index 4b067c1ceea9..d86d78886e23 100644 --- a/tests/test_moe/moe_utils.py +++ b/tests/test_moe/moe_utils.py @@ -1,8 +1,15 @@ import torch.nn as nn from colossalai.context import MOE_CONTEXT +from colossalai.context.moe_context import MOE_CONTEXT +from colossalai.context.parallel_mode import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.engine.gradient_handler._base_gradient_handler import BaseGradientHandler +from colossalai.engine.gradient_handler.utils import bucket_allreduce from colossalai.nn import CheckpointModule from colossalai.nn.layer import MoeModule +from colossalai.registry import GRADIENT_HANDLER +from colossalai.utils.moe import get_moe_epsize_param_dict class MoeModel(nn.Module): @@ -39,3 +46,41 @@ def forward(self, x): MOE_CONTEXT.add_loss(y) return x + + +@GRADIENT_HANDLER.register_module +class MoeGradientHandler(BaseGradientHandler): + """A helper class to handle all-reduce operations in a data parallel group and + moe model parallel. 
A all-reduce collective communication will be operated in + :func:`handle_gradient` among a data parallel group. + For better performance, it bucketizes the gradients of all parameters that are + the same type to improve the efficiency of communication. + + Args: + model (Module): Model where the gradients accumulate. + optimizer (Optimizer): Optimizer for updating the parameters. + """ + + def __init__(self, model, optimizer=None): + super().__init__(model, optimizer) + + def handle_gradient(self): + """A method running an all-reduce operation in a data parallel group. + Then running an all-reduce operation for all parameters in experts + across moe model parallel group + """ + global_data = gpc.data_parallel_size + + if global_data > 1: + epsize_param_dict = get_moe_epsize_param_dict(self._model) + + # epsize is 1, indicating the params are replicated among processes in data parallelism + # use the ParallelMode.DATA to get data parallel group + # reduce gradients for all parameters in data parallelism + if 1 in epsize_param_dict: + bucket_allreduce(param_list=epsize_param_dict[1], group=gpc.get_group(ParallelMode.DATA)) + + for ep_size in epsize_param_dict: + if ep_size != 1 and ep_size != MOE_CONTEXT.world_size: + bucket_allreduce(param_list=epsize_param_dict[ep_size], + group=MOE_CONTEXT.parallel_info_dict[ep_size].dp_group) diff --git a/tests/test_moe/test_grad_handler.py b/tests/test_moe/test_grad_handler.py index a4473bf8eea4..87f0f4b2abe4 100644 --- a/tests/test_moe/test_grad_handler.py +++ b/tests/test_moe/test_grad_handler.py @@ -5,11 +5,11 @@ import colossalai from colossalai.context.moe_context import MOE_CONTEXT -from colossalai.engine.gradient_handler import MoeGradientHandler from colossalai.nn.layer.moe import Experts, MoeLayer, Top1Router, UniformNoiseGenerator from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use, spawn from colossalai.utils import get_current_device from colossalai.utils.moe import sync_moe_model_param +from tests.test_moe.moe_utils import MoeGradientHandler BATCH_SIZE = 4 DIM = 16 diff --git a/tests/test_moe/test_moe_zero_fwd_bwd.py b/tests/test_moe/test_moe_zero_fwd_bwd.py index 83ec884b1515..e2acb0702f1c 100644 --- a/tests/test_moe/test_moe_zero_fwd_bwd.py +++ b/tests/test_moe/test_moe_zero_fwd_bwd.py @@ -6,11 +6,10 @@ from colossalai.booster.plugin import LowLevelZeroPlugin from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel from colossalai.context import MOE_CONTEXT -from colossalai.engine.gradient_handler import MoeGradientHandler from colossalai.nn import MoeLoss from colossalai.testing import rerun_if_address_is_in_use, spawn from colossalai.testing.random import seed_all -from tests.test_moe.moe_utils import MoeModel +from tests.test_moe.moe_utils import MoeGradientHandler, MoeModel def split_ddp_grad(grad, world_size): diff --git a/tests/test_moe/test_moe_zero_optim.py b/tests/test_moe/test_moe_zero_optim.py index fbd64def9b41..fcb6f95d1319 100644 --- a/tests/test_moe/test_moe_zero_optim.py +++ b/tests/test_moe/test_moe_zero_optim.py @@ -1,144 +1,92 @@ import pytest import torch -import torch.distributed as dist import colossalai -from colossalai.amp import convert_to_apex_amp +from colossalai.booster import Booster +from colossalai.booster.plugin import LowLevelZeroPlugin +from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel from colossalai.context import MOE_CONTEXT -from colossalai.engine.gradient_handler import MoeGradientHandler from colossalai.nn import MoeLoss 
-from colossalai.nn.optimizer import CPUAdam -from colossalai.testing import assert_equal_in_group, parameterize, rerun_if_address_is_in_use, spawn -from colossalai.utils import get_current_device -from colossalai.zero.legacy.init_ctx import ZeroInitContext -from colossalai.zero.legacy.shard_utils import BucketTensorShardStrategy, TensorShardStrategy -from colossalai.zero.legacy.sharded_model import ShardedModelV2 -from colossalai.zero.legacy.sharded_model.utils import col_model_deepcopy -from colossalai.zero.legacy.sharded_optim import ShardedOptimizerV2 -from colossalai.zero.low_level._utils import has_inf_or_nan -from tests.components_to_test.registry import non_distributed_component_funcs -from tests.test_moe.moe_utils import MoeModel - - -def allclose(tensor_a: torch.Tensor, tensor_b: torch.Tensor, loose=False) -> bool: - if loose: - return torch.allclose(tensor_a, tensor_b, atol=1e-2, rtol=1e-3) - return torch.allclose(tensor_a, tensor_b) - - -def check_sharded_model_params(model, zero_model, loose=False, reuse_fp16_shard=False): - rank = dist.get_rank() - for (name, p), (zero_name, zero_p) in zip(model.named_parameters(), zero_model.named_parameters()): - if zero_p.colo_attr.param_is_sharded: - zero_p = zero_p.colo_attr.data_payload.to(p.device).float() - chunks = torch.flatten(p).chunk(dist.get_world_size()) - if rank >= len(chunks): - continue - p = chunks[rank].float() - if zero_p.size(0) > p.size(0): - zero_p = zero_p[:p.size(0)] - else: - zero_p = zero_p.colo_attr.data_payload.to(p.device) +from colossalai.testing import rerun_if_address_is_in_use, spawn +from tests.test_moe.moe_utils import MoeGradientHandler, MoeModel - assert p.dtype == zero_p.dtype, "Parameter `{}`:\n{} vs {}".format(name, p.dtype, zero_p.dtype) - assert allclose(p, zero_p, loose=loose), f'{p} vs {zero_p}' +def split_ddp_grad(grad, world_size): + with torch.no_grad(): + grad = grad.clone().detach().flatten() + padding_size = (world_size - grad.numel() % world_size) % world_size + if padding_size > 0: + grad = torch.nn.functional.pad(grad, [0, padding_size]) + splited_grad = grad.split(grad.numel() // world_size) + return splited_grad -def _run_step(model, optimizer, data, label, criterion, grad_handler): - model.train() - optimizer.zero_grad() - if criterion: - y = model(data) - loss = criterion(y, label) - else: - loss = model(data, label) +def run_fwd_bwd(model, data, label, criterion, optimizer, enable_autocast=False): + model.train() + with torch.cuda.amp.autocast(enabled=enable_autocast): + if criterion: + y = model(data) + loss = criterion(y, label) + else: + loss = model(data, label) + loss = loss.float() - loss = loss.float() - if isinstance(model, ShardedModelV2): + if isinstance(model, LowLevelZeroModel): optimizer.backward(loss) else: loss.backward() + return y - if grad_handler is not None: - grad_handler.handle_gradient() - optimizer.step() - - -@parameterize("cpu_offload", [True]) -@parameterize("use_cpuadam", [True]) # We do not use Hybrid Adam right now, since it has a little bug -@parameterize("reuse_fp16_shard", [True, False]) -@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy]) -def _run_test_sharded_optim_v2(cpu_offload, - shard_strategy_class, - use_cpuadam, - reuse_fp16_shard, - gpu_margin_mem_ratio=0.0): - shard_strategy = shard_strategy_class() - if use_cpuadam and cpu_offload is False: - return - MOE_CONTEXT.reset_loss() - get_components_func = non_distributed_component_funcs.get_callable('hanging_param_model') - _, train_dataloader, _, 
optimizer_class, _ = get_components_func() +def run_zero_optim_test(local_rank, world_size, stage=1): criterion = MoeLoss(aux_weight=0.01, loss_fn=torch.nn.CrossEntropyLoss) - with ZeroInitContext(target_device=torch.device('cpu') if cpu_offload else get_current_device(), - shard_strategy=shard_strategy, - shard_param=True): - zero_model = MoeModel(checkpoint=True) - - zero_model = ShardedModelV2(zero_model, - shard_strategy, - tensor_placement_policy='cpu' if cpu_offload else 'cuda', - reuse_fp16_shard=reuse_fp16_shard) - - # check whether parameters are identical in ddp - for name, p in zero_model.named_parameters(): - if not p.colo_attr.param_is_sharded and p.colo_attr.is_replicated: - assert_equal_in_group(p.colo_attr.data_payload.to(get_current_device())) - - model = MoeModel(checkpoint=True).half() - col_model_deepcopy(zero_model, model) - model = model.cuda().float() - - if use_cpuadam: - optimizer_class = CPUAdam - optim = optimizer_class(model.parameters(), lr=1e-3) - sharded_optim = optimizer_class(zero_model.parameters(), lr=1e-3) - sharded_optim = ShardedOptimizerV2(zero_model, - sharded_optim, - initial_scale=2**5, - gpu_margin_mem_ratio=gpu_margin_mem_ratio) - - amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False) - apex_model, apex_optimizer = convert_to_apex_amp(model, optim, amp_config) - apex_grad_handler = MoeGradientHandler(model) - - for i, (data, label) in enumerate(train_dataloader): - if i > 5: - break - data, label = data.cuda(), label.cuda() - _run_step(apex_model, apex_optimizer, data, label, criterion, apex_grad_handler) - _run_step(zero_model, sharded_optim, data, label, criterion, None) - check_sharded_model_params(model, zero_model, loose=True, reuse_fp16_shard=use_cpuadam) - for param in model.parameters(): - assert not has_inf_or_nan(param) - - -def _run_dist(rank, world_size, port): + zero_model = MoeModel(checkpoint=True) + zero_optimizer = torch.optim.Adam(zero_model.parameters()) + plugin = LowLevelZeroPlugin(stage=stage, precision="fp32") + booster = Booster(plugin=plugin) + zero_model, zero_optimizer, _, _, _ = booster.boost(zero_model, zero_optimizer) + + torch_model = MoeModel(checkpoint=True) + for zero_param, torch_param in zip(zero_model.parameters(), torch_model.parameters()): + torch_param.data.copy_(zero_param.data) + torch_optimizer = torch.optim.Adam(torch_model.parameters()) + torch_model = torch_model.cuda() + grad_handler = MoeGradientHandler(torch_model) + + for _ in range(2): + data = torch.randn(16, 4).cuda() / (local_rank + 1) + label = torch.randint(0, 4, (16,)).cuda() + run_fwd_bwd(torch_model, data, label, criterion, None) + run_fwd_bwd(zero_model, data, label, criterion, zero_optimizer) + grad_handler.handle_gradient() + + torch_optimizer.step() + zero_optimizer.step() + + for (torch_name, torch_param), (zero_name, zero_param) in zip(torch_model.named_parameters(), + zero_model.named_parameters()): + assert torch.allclose( + torch_param.data, + zero_param.data), f"{torch_name}\ntorch_param {torch_param.data}\nzero_param {zero_param.data}" + + torch_optimizer.zero_grad() + zero_optimizer.zero_grad() + + +def run_dist(rank, world_size, port): colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') MOE_CONTEXT.setup(seed=42) - _run_test_sharded_optim_v2() + run_zero_optim_test(rank, world_size, stage=1) + run_zero_optim_test(rank, world_size, stage=2) -# use_cpuadam = True can be used with cpu_offload = False @pytest.mark.dist @pytest.mark.parametrize("world_size", [2]) 
@rerun_if_address_is_in_use() def test_moe_zero_optim(world_size): - spawn(_run_dist, world_size) + spawn(run_dist, world_size) if __name__ == '__main__': - test_moe_zero_optim(world_size=4) + test_moe_zero_optim(world_size=2)
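
Usage sketch (reviewer note, not part of the patch). A minimal illustration of how the new helpers in colossalai/tensor/moe_tensor/api.py and the extra optimizer param group introduced above are meant to fit together. The toy expert module, the `dist_info` dict, and the bare `torch.optim.Adam` below are illustrative stand-ins; only `set_moe_param_info` and `is_moe_param` come from this patch.

```python
# Sketch under the assumption that this patch is applied; names marked
# "stand-in" are hypothetical and only exist for illustration.
import torch
import torch.nn as nn

from colossalai.tensor.moe_tensor.api import is_moe_param, set_moe_param_info

expert = nn.Linear(4, 4)        # stand-in for one expert module inside Experts
dist_info = {"ep_size": 2}      # stand-in for Experts.dist_info (a parallel-info object)

# Experts.__init__ now tags every expert parameter via the API
# instead of calling param.__setattr__('moe_info', ...) directly.
for param in expert.parameters():
    set_moe_param_info(param, dist_info)

# LowLevelZeroOptimizer.__init__ skips tagged params when building its
# working (flattened, sharded) groups ...
moe_params = [p for p in expert.parameters() if is_moe_param(p)]
assert len(moe_params) == 2     # weight and bias are both tagged

# ... and appends them as an extra param group that reuses the first
# group's hyperparameters, mirroring the logic added in low_level_optim.py.
optim = torch.optim.Adam(nn.Linear(4, 4).parameters(), lr=1e-3)
extra_group = {k: v for k, v in optim.param_groups[0].items() if k != 'params'}
extra_group['params'] = moe_params
optim.param_groups.append(extra_group)

# After optim.step(), the wrapper releases the MoE grads held by that last group.
for p in optim.param_groups[-1]['params']:
    p.grad = None
```

The MoE params are kept out of the flattened working groups because they follow an expert-parallel layout rather than the plain data-parallel sharding the low-level ZeRO optimizer assumes; the separate group lets the inner optimizer still update them while their gradients are reduced by the (now test-local) MoeGradientHandler.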