diff --git a/llmfoundry/callbacks/__init__.py b/llmfoundry/callbacks/__init__.py index d9bb3c24a7..dc3ee707ac 100644 --- a/llmfoundry/callbacks/__init__.py +++ b/llmfoundry/callbacks/__init__.py @@ -11,6 +11,8 @@ from llmfoundry.callbacks.eval_gauntlet_callback import EvalGauntlet from llmfoundry.callbacks.fdiff_callback import FDiffMetrics from llmfoundry.callbacks.hf_checkpointer import HuggingFaceCheckpointer +from llmfoundry.callbacks.log_mbmoe_tok_per_expert_callback import \ + MegaBlocksMoE_TokPerExpert from llmfoundry.callbacks.monolithic_ckpt_callback import \ MonolithicCheckpointSaver from llmfoundry.callbacks.resumption_callbacks import (GlobalLRScaling, @@ -34,6 +36,7 @@ callbacks.register('scheduled_gc', func=ScheduledGarbageCollector) callbacks.register('oom_observer', func=OOMObserver) callbacks.register('eval_output_logging', func=EvalOutputLogging) +callbacks.register('mbmoe_tok_per_expert', func=MegaBlocksMoE_TokPerExpert) callbacks_with_config.register('async_eval', func=AsyncEval) callbacks_with_config.register('curriculum_learning', func=CurriculumLearning) @@ -46,6 +49,7 @@ 'ScheduledGarbageCollector', 'EvalGauntlet', 'HuggingFaceCheckpointer', + 'MegaBlocksMoE_TokPerExpert', 'AsyncEval', 'CurriculumLearning', ] diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py index b7d80bd5f8..baa72a7f66 100644 --- a/llmfoundry/callbacks/hf_checkpointer.py +++ b/llmfoundry/callbacks/hf_checkpointer.py @@ -15,6 +15,7 @@ from typing import Any, Dict, List, Optional, Sequence, Union import torch +import torch.nn as nn from composer.core import Callback, Event, State, Time, TimeUnit from composer.core.state import fsdp_state_dict_type_context from composer.loggers import Logger, MLFlowLogger @@ -24,6 +25,7 @@ parse_uri) from composer.utils.misc import create_interval_scheduler from mlflow.transformers import _fetch_model_card, _write_license_information +from packaging import version from transformers import 
PreTrainedModel, PreTrainedTokenizerBase from llmfoundry.models.mpt import MPTConfig, MPTForCausalLM @@ -312,28 +314,72 @@ def _save_checkpoint(self, state: State, logger: Logger): state_dict_model = state.model.model original_tokenizer = state.model.tokenizer - state_dict_context = fsdp_state_dict_type_context( - original_model, - state_dict_type='full') if ((not state.is_model_ddp) and isinstance( - state_dict_model, FSDP)) else contextlib.nullcontext() - - with state_dict_context: - state_dict = state_dict_model.state_dict() - - # convert the state dict to the requested precision - for k, v in state_dict.items(): - if isinstance(v, torch.Tensor): - state_dict[k] = v.to(dtype=self.dtype) + if version.parse(torch.__version__) > version.parse('2.2.9'): + from torch.distributed._tensor import DTensor + from torch.distributed.checkpoint.state_dict import ( + StateDictOptions, get_model_state_dict) + cpu_offload = True + + # Add a dtensor->cpu tensor hook to avoid CUDA OOM + def dtensor_to_tensor_hook( + module: nn.Module, + state_dict: Dict[str, Any], + prefix: str, + *args: Any, + ) -> Dict[str, Any]: + dtensor_fqns = [] + for fqn in state_dict.keys(): + tensor = state_dict[fqn] + if isinstance(tensor, DTensor): + dtensor_fqns.append(fqn) + tensor = tensor.full_tensor() # type: ignore + if dist.get_global_rank() == 0: + if cpu_offload: + tensor = tensor.cpu() + state_dict[fqn] = tensor + if dist.get_global_rank() != 0: + for fqn in dtensor_fqns: + del state_dict[fqn] + return state_dict + + hooks = [] + for _, module in state_dict_model.named_modules(): + if isinstance(module, FSDP): + hooks.append( + module._register_state_dict_hook( + dtensor_to_tensor_hook)) + + state_dict = get_model_state_dict(state_dict_model, + options=StateDictOptions( + full_state_dict=True, + cpu_offload=cpu_offload)) + for hook in hooks: + hook.remove() + else: + state_dict_context = fsdp_state_dict_type_context( + original_model, state_dict_type='full') if ( + (not state.is_model_ddp) 
and isinstance( + state_dict_model, FSDP)) else contextlib.nullcontext() + with state_dict_context: + state_dict = state_dict_model.state_dict() + + # Convert the state dict to the requested precis + for k, v in state_dict.items(): + if isinstance(v, torch.Tensor): + state_dict[k] = v.to(dtype=self.dtype) new_model_instance = None # Need this for pyright because variable could be unbound if dist.get_global_rank() == 0: log.debug('Saving Hugging Face checkpoint in global rank 0') + # Edit HF config before building 2nd model copy copied_config = copy.deepcopy(original_model.config) if copied_config.model_type == 'mpt': copied_config.attn_config['attn_impl'] = 'torch' copied_config.init_device = 'cpu' + if 'moe_world_size' in getattr(copied_config, 'ffn_config', {}): + copied_config.ffn_config['moe_world_size'] = 1 log.debug(f'Creating new model instance') diff --git a/llmfoundry/callbacks/log_mbmoe_tok_per_expert_callback.py b/llmfoundry/callbacks/log_mbmoe_tok_per_expert_callback.py new file mode 100644 index 0000000000..fc906e0d87 --- /dev/null +++ b/llmfoundry/callbacks/log_mbmoe_tok_per_expert_callback.py @@ -0,0 +1,140 @@ +# Copyright 2022 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +"""Log tokens per expert for MegaBlocks MoE.""" +from __future__ import annotations + +import torch +from composer.core import Callback, State +from composer.loggers import Logger +from composer.utils import dist + + +class MegaBlocksMoE_TokPerExpert(Callback): + """Log tokens per expert for MegaBlocks MoE. + + To compute the load balancing loss, MegaBlocks caches information including `tokens_per_expert` + (tpe). At the :attr:`.Event.BATCH_END` event this callback gets load_balancing_loss from + MegaBlocks to get `tokens_per_expert` then logs statistics () of the number of tokens + assigned to experts for each layer index (l_idx) under ``mb_moe/layer__tpe``. 
+ + + The tokens_per_expert statistics are logged by the :class:`.Logger` to the following keys as + described below. + + +----------------------------------+-----------------------------------------------------------+ + | Key | Logged data | + +==================================+===========================================================+ + | `mb_moe/alllayer_min_tpe` | Minimum tokens per expert across all layers | + +----------------------------------+-----------------------------------------------------------+ + | `mb_moe/alllayer_max_tpe` | Maximum tokens per expert across all layers | + +----------------------------------+-----------------------------------------------------------+ + | `mb_moe/alllayer_median_tpe` | Median tokens per expert across all layers | + +----------------------------------+-----------------------------------------------------------+ + | `mb_moe/alllayer_std_tpe` | Standard deviation of tokens per expert across all layers | + +----------------------------------+-----------------------------------------------------------+ + | `mb_moe/layer_min_tpe` | Minimum tokens per expert at l_idx layer | + +----------------------------------+-----------------------------------------------------------+ + | `mb_moe/layer_max_tpe` | Maximum tokens per expert at l_idx layer | + +----------------------------------+-----------------------------------------------------------+ + | `mb_moe/layer_median_tpe` | Median tokens per expert at l_idx layer | + +----------------------------------+-----------------------------------------------------------+ + | `mb_moe/layer_std_tpe` | Standard deviation of tokens per expert at l_idx layer | + +----------------------------------+-----------------------------------------------------------+ + + Args: + log_interval (int, optional): The interval on which to log (Default: 10). + log_every_layer (bool, optional): Enable logging ever layer's statisictics (True) or log + only aggregate statistics (Default: False). 
+ all_reduce_stats (bool, optional): Enable aggregating statistics across gpus (True) or log + statistics for GPU 0 (Default: False). + normalize (bool, optional): Normalize token counts by total tokens (Default: True) or output + raw token count (False). When normalize is True, the callback displays the fraction of + unique tokens routed to each expert. When normalize is False, the callback displays the + total number of tokens routed to each expert. + """ + + def __init__( + self, + log_interval: int = 10, + log_every_layer: bool = False, + all_reduce_stats: bool = False, + normalize: bool = True, + ): + self.log_interval = log_interval + self.log_every_layer = log_every_layer + self.all_reduce_stats = all_reduce_stats + self.normalize = normalize + + self.topk = None + + def fit_start(self, state: State, logger: Logger) -> None: + if self.topk is None and self.normalize: + try: + from megablocks.layers.dmoe import dMoE + from megablocks.layers.moe import MoE + except: + raise RuntimeError( + 'Requirements for MegaBlocks not installed; see install instructions in `README.md`.' + ) + for module in state.model.modules(): + if isinstance(module, (MoE, dMoE)): + self.topk = module.experts.args.moe_top_k + return + + raise RuntimeError( + f'Callback not initialized correctly; self.topk not instantiated.' + ) + + def batch_end(self, state: State, logger: Logger) -> None: + if state.timestamp.batch.value % self.log_interval == 0: + try: + from megablocks.layers.moe import get_load_balancing_loss + except: + raise RuntimeError( + 'Requirements for MegaBlocks not installed; see install instructions in `README.md`.' 
+ ) + tokens_per_expert, _ = zip(*get_load_balancing_loss()) + + tokens_per_expert = [ + tpe.clone().detach() for tpe in tokens_per_expert + ] + if self.all_reduce_stats: + for tpe in tokens_per_expert: + dist.all_reduce(tpe) + + if self.normalize: + tokens_per_expert = [ + tpe / (tpe.sum() / self.topk) for tpe in tokens_per_expert + ] + + all_tokens_per_expert = torch.concat(tokens_per_expert) + + min_tpe = all_tokens_per_expert.min().item() + max_tpe = all_tokens_per_expert.max().item() + median_tpe = all_tokens_per_expert.median().item() + std_tpe = all_tokens_per_expert.float().std().item() + + log_info = { + f'mb_moe/all_layers_min_tpe': min_tpe, + f'mb_moe/all_layers_max_tpe': max_tpe, + f'mb_moe/all_layers_median_tpe': median_tpe, + f'mb_moe/all_layers_std_tpe': std_tpe, + } + + if self.log_every_layer: + for l_idx, tpe_layer in enumerate(tokens_per_expert): + + min_tpe = tpe_layer.min().item() + max_tpe = tpe_layer.max().item() + median_tpe = tpe_layer.median().item() + std_tpe = tpe_layer.float().std().item() + + log_info.update({ + f'mb_moe/layer{l_idx}_min_tpe': min_tpe, + f'mb_moe/layer{l_idx}_max_tpe': max_tpe, + f'mb_moe/layer{l_idx}_median_tpe': median_tpe, + f'mb_moe/layer{l_idx}_std_tpe': std_tpe, + }) + + logger.log_metrics(log_info) diff --git a/llmfoundry/data/finetuning/dataloader.py b/llmfoundry/data/finetuning/dataloader.py index 38c9673a14..1d8711d280 100644 --- a/llmfoundry/data/finetuning/dataloader.py +++ b/llmfoundry/data/finetuning/dataloader.py @@ -170,6 +170,8 @@ def build_finetuning_dataloader(cfg: DictConfig, sampling_granularity=cfg.dataset.get('sampling_granularity', 1), batching_method=cfg.dataset.get('batching_method', 'random'), max_seq_len=cfg.dataset.max_seq_len, + allow_unsafe_types=cfg.dataset.get('allow_unsafe_types', False), + replication=cfg.dataset.get('replication', None), ) else: diff --git a/llmfoundry/data/finetuning/tasks.py b/llmfoundry/data/finetuning/tasks.py index 4ca15e8d1f..42b15e4d6e 100644 --- 
a/llmfoundry/data/finetuning/tasks.py +++ b/llmfoundry/data/finetuning/tasks.py @@ -490,6 +490,12 @@ class StreamingFinetuningDataset(StreamingDataset): Defaults to ``1``. batching_method (str): Which batching method to use, either ``random``, ``stratified``, or ``per_stream``. Defaults to ``random``. + allow_unsafe_types (bool): If a shard contains Pickle, which allows arbitrary code + execution during deserialization, whether to keep going if ``True`` or raise an error + if ``False``. Defaults to ``False``. + replication (int, optional): Determines how many consecutive devices will receive the same + samples. Useful for training with tensor or sequence parallelism, where multiple + devices need to see the same partition of the dataset. Defaults to ``None``. """ def __init__(self, @@ -516,6 +522,8 @@ def __init__(self, sampling_granularity: int = 1, batching_method: str = 'random', max_seq_len: int = 2048, + allow_unsafe_types: bool = False, + replication: Optional[int] = None, **kwargs: Any): if len(kwargs) > 0: @@ -552,6 +560,8 @@ def __init__(self, sampling_method=sampling_method, sampling_granularity=sampling_granularity, batching_method=batching_method, + allow_unsafe_types=allow_unsafe_types, + replication=replication, ) self.tokenizer = tokenizer diff --git a/llmfoundry/data/text_data.py b/llmfoundry/data/text_data.py index e85968543c..fc31b890b0 100644 --- a/llmfoundry/data/text_data.py +++ b/llmfoundry/data/text_data.py @@ -83,6 +83,12 @@ class StreamingTextDataset(StreamingDataset): Defaults to ``1``. batching_method (str): Which batching method to use, either ``random``, ``stratified``, or ``per_stream``. Defaults to ``random``. + allow_unsafe_types (bool): If a shard contains Pickle, which allows arbitrary code + execution during deserialization, whether to keep going if ``True`` or raise an error + if ``False``. Defaults to ``False``. + replication (int, optional): Determines how many consecutive devices will receive the same + samples. 
Useful for training with tensor or sequence parallelism, where multiple + devices need to see the same partition of the dataset. Defaults to ``None``. """ def __init__(self, @@ -109,6 +115,8 @@ def __init__(self, sampling_method: str = 'balanced', sampling_granularity: int = 1, batching_method: str = 'random', + allow_unsafe_types: bool = False, + replication: Optional[int] = None, **kwargs: Any): if len(kwargs) > 0: @@ -151,6 +159,8 @@ def __init__(self, sampling_method=sampling_method, sampling_granularity=sampling_granularity, batching_method=batching_method, + allow_unsafe_types=allow_unsafe_types, + replication=replication, ) self.tokenizer = tokenizer self.max_seq_len = max_seq_len diff --git a/llmfoundry/models/layers/blocks.py b/llmfoundry/models/layers/blocks.py index 42feb983d4..18b9f979f4 100644 --- a/llmfoundry/models/layers/blocks.py +++ b/llmfoundry/models/layers/blocks.py @@ -68,10 +68,146 @@ def __init__( ffn_config = { 'ffn_type': 'mptmlp', } + self.fuse_norm_attn_norm = kwargs.get('fuse_norm_attn_norm', False) del kwargs # unused, just to capture any extra args from the config super().__init__() + if self.fuse_norm_attn_norm: + self.norm_attn_norm = FusedNormAttentionNorm( + d_model=d_model, + n_heads=n_heads, + attn_config=attn_config, + ffn_config=ffn_config, + fc_type=fc_type, + resid_pdrop=resid_pdrop, + norm_type=norm_type, + device=device, + no_bias=no_bias, + ) + else: + assert isinstance(attn_config['attn_type'], str) + attn_class = ATTN_CLASS_REGISTRY[attn_config['attn_type']] + + # Necessary to avoid passing extraneous args into attn_class while allowing the use of **kwargs + args_to_exclude_in_attn_class = { + 'attn_type', 'alibi', 'attn_uses_sequence_id', 'alibi_bias_max', + 'rope', 'rope_theta', 'rope_impl', 'rope_dail_config', + 'rope_hf_config' + } + attn_config_subset_for_attn_class = { + k: v + for k, v in attn_config.items() + if k not in args_to_exclude_in_attn_class + } + + self.norm_1 = build_norm( + name=norm_type.lower(), + 
normalized_shape=d_model, + device=device, + ) + self.attn = attn_class( + d_model=d_model, + n_heads=n_heads, + fc_type=fc_type, + device=device, + **attn_config_subset_for_attn_class, + bias=not no_bias, + ) + self.norm_2 = None + if not getattr(FFN_CLASS_REGISTRY[ffn_config['ffn_type']], + '_has_norm', False): + self.norm_2 = build_norm( + name=norm_type.lower(), + normalized_shape=d_model, + device=device, + ) + + self.ffn = build_ffn( + d_model=d_model, + expansion_ratio=expansion_ratio, + device=device, + bias=not no_bias, + **ffn_config, + ) + self.resid_attn_dropout = nn.Dropout(resid_pdrop) + self.resid_ffn_dropout = nn.Dropout(resid_pdrop) + self.use_pad_tok_in_ffn = use_pad_tok_in_ffn + + def forward( + self, + x: torch.Tensor, + past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + attn_bias: Optional[torch.Tensor] = None, + rotary_emb_w_meta_info: Optional[Dict] = None, + attention_mask: Optional[torch.ByteTensor] = None, + is_causal: bool = True, + output_attentions: bool = False, + alibi_slopes: Optional[torch.Tensor] = None, + flash_attn_padding_info: Optional[dict[str, torch.Tensor]] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[ + torch.Tensor, torch.Tensor]]]: + if self.fuse_norm_attn_norm: + x, m, attn_weights, past_key_value = self.norm_attn_norm( + x, + past_key_value=past_key_value, + attn_bias=attn_bias, + rotary_emb_w_meta_info=rotary_emb_w_meta_info, + attention_mask=attention_mask, + is_causal=is_causal, + output_attentions=output_attentions, + alibi_slopes=alibi_slopes, + flash_attn_padding_info=flash_attn_padding_info, + ) + else: + a = self.norm_1(x) + b, attn_weights, past_key_value = self.attn( + a, + past_key_value=past_key_value, + attn_bias=attn_bias, + rotary_emb_w_meta_info=rotary_emb_w_meta_info, + attention_mask=attention_mask, + is_causal=is_causal, + needs_weights=output_attentions, + alibi_slopes=alibi_slopes, + flash_attn_padding_info=flash_attn_padding_info, + ) + x = x + 
self.resid_attn_dropout(b) + m = x + if self.norm_2 is not None: + m = self.norm_2(x) + + batch_size, seq_len = m.size()[:2] + indices = None + if not self.use_pad_tok_in_ffn: + assert unpad_input is not None + m, indices, _, _ = unpad_input(m, attention_mask) + n = self.ffn(m) + if not self.use_pad_tok_in_ffn: + assert pad_input is not None + n = pad_input(n, indices, batch_size, seq_len) + x = x + self.resid_ffn_dropout(n) + return x, attn_weights, past_key_value + + +class FusedNormAttentionNorm(nn.Module): + + def __init__( + self, + d_model: int, + n_heads: int, + attn_config: Optional[Dict] = None, + ffn_config: Optional[Dict] = None, + fc_type: str = 'torch', + resid_pdrop: float = 0.0, + norm_type: str = 'low_precision_layernorm', + device: Optional[str] = None, + no_bias: bool = False, + **kwargs: Any, + ): + super().__init__() + assert attn_config is not None + assert ffn_config is not None assert isinstance(attn_config['attn_type'], str) attn_class = ATTN_CLASS_REGISTRY[attn_config['attn_type']] @@ -86,7 +222,6 @@ def __init__( for k, v in attn_config.items() if k not in args_to_exclude_in_attn_class } - self.norm_1 = build_norm( name=norm_type.lower(), normalized_shape=d_model, @@ -108,17 +243,7 @@ def __init__( normalized_shape=d_model, device=device, ) - self.ffn = build_ffn( - d_model=d_model, - expansion_ratio=expansion_ratio, - device=device, - bias=not no_bias, - **ffn_config, - ) self.resid_attn_dropout = nn.Dropout(resid_pdrop) - self.resid_ffn_dropout = nn.Dropout(resid_pdrop) - - self.use_pad_tok_in_ffn = use_pad_tok_in_ffn def forward( self, @@ -131,8 +256,8 @@ def forward( output_attentions: bool = False, alibi_slopes: Optional[torch.Tensor] = None, flash_attn_padding_info: Optional[dict[str, torch.Tensor]] = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[ - torch.Tensor, torch.Tensor]]]: + ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], + Optional[Tuple[torch.Tensor, torch.Tensor]]]: a = 
self.norm_1(x) b, attn_weights, past_key_value = self.attn( a, @@ -149,14 +274,5 @@ def forward( m = x if self.norm_2 is not None: m = self.norm_2(x) - batch_size, seq_len = m.size()[:2] - indices = None - if not self.use_pad_tok_in_ffn: - assert unpad_input is not None - m, indices, _, _ = unpad_input(m, attention_mask) - n = self.ffn(m) - if not self.use_pad_tok_in_ffn: - assert pad_input is not None - n = pad_input(n, indices, batch_size, seq_len) - x = x + self.resid_ffn_dropout(n) - return x, attn_weights, past_key_value + + return x, m, attn_weights, past_key_value diff --git a/llmfoundry/models/layers/dmoe.py b/llmfoundry/models/layers/dmoe.py new file mode 100644 index 0000000000..1a981b61c5 --- /dev/null +++ b/llmfoundry/models/layers/dmoe.py @@ -0,0 +1,246 @@ +# Copyright 2024 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +from typing import Callable + +import torch + + +# Add option to route tokens uniformly across experts. We use +# a custom autograd op router backwards is still run for benchmarking. 
+class _UniformExpertAssignment(torch.autograd.Function): + + @staticmethod + def forward( + ctx, # pyright: ignore[reportMissingParameterType] + x: torch.Tensor, + num_experts: int): + out = torch.arange(x.numel(), dtype=x.dtype, device=x.device) + out = torch.remainder(out, num_experts) + return out.view(x.shape) + + +class LearnedRouter(torch.nn.Module): + + def __init__(self, hidden_size: int, moe_num_experts: int, moe_top_k: int, + moe_jitter_eps: float, moe_normalize_expert_weights: bool, + uniform_expert_assignment: bool, device: torch.device) -> None: + super().__init__() + self.hidden_size: int = hidden_size + self.moe_num_experts: int = moe_num_experts + self.moe_top_k: int = moe_top_k + self.moe_jitter_eps: float = moe_jitter_eps + self.moe_normalize_expert_weights: bool = moe_normalize_expert_weights + self.uniform_expert_assignment: bool = uniform_expert_assignment + + self.layer: torch.nn.Module = torch.nn.Linear( + hidden_size, + moe_num_experts, + bias=False, + device=device, + ) + + def jitter(self, x: torch.Tensor) -> torch.Tensor: + low: float = 1.0 - self.moe_jitter_eps + high: float = 1.0 + self.moe_jitter_eps + noise: torch.Tensor = torch.rand(x.size(), + dtype=x.dtype, + device=x.device) + return low + noise * (high - low) + + def _top_k(self, scores: torch.Tensor) -> torch.Tensor: + if self.moe_top_k == 1: + return scores.max( + dim=-1) # pyright: ignore[reportGeneralTypeIssues] + return torch.topk(scores, self.moe_top_k, + dim=-1) # pyright: ignore[reportGeneralTypeIssues] + + def forward(self, x: torch.Tensor): + if self.training and self.moe_jitter_eps is not None: + x = x * self.jitter(x) + + scores = self.layer(x.view(-1, x.shape[-1])).softmax(dim=-1) + expert_weights, top_experts = self._top_k(scores) + if self.moe_normalize_expert_weights: + expert_weights = expert_weights / torch.norm( + expert_weights, + p=self.moe_normalize_expert_weights, + dim=-1, + keepdim=True) + + top_experts = (_UniformExpertAssignment.apply(top_experts, + 
self.moe_num_experts) + if self.uniform_expert_assignment else top_experts) + scores = scores.to(x.dtype) + expert_weights = expert_weights.to(x.dtype) + return scores, expert_weights, top_experts + + +class MLP(torch.nn.Module): + + def __init__( + self, + hidden_size: int, + ffn_hidden_size: int, + moe_num_experts: int, + activation_fn: Callable, + device: torch.device, + ) -> None: + super().__init__() + + self.moe_num_experts: int = moe_num_experts + self.ffn_hidden_size: int = ffn_hidden_size + self.hidden_size: int = hidden_size + self.activation_fn: Callable = activation_fn + + self.w1 = torch.nn.Parameter( + torch.rand(moe_num_experts * ffn_hidden_size, + hidden_size, + device=device)) + self.w2 = torch.nn.Parameter( + torch.rand(moe_num_experts * ffn_hidden_size, + hidden_size, + device=device)) + self.activation_fn = activation_fn + + def forward(self, x: torch.Tensor, expert_idx: int) -> torch.Tensor: + expert_w1 = self.w1.view(self.moe_num_experts, self.ffn_hidden_size, + self.hidden_size)[expert_idx] + expert_w2 = self.w2.view(self.moe_num_experts, self.ffn_hidden_size, + self.hidden_size)[expert_idx] + + before_activation = x @ expert_w1.t() + layer_1_output = self.activation_fn(before_activation) + output = layer_1_output @ expert_w2 + return output + + +class GLU(torch.nn.Module): + + def __init__(self, hidden_size: int, ffn_hidden_size: int, + moe_num_experts: int, activation_fn: Callable, + device: torch.device): + super().__init__() + self.hidden_size = hidden_size + self.ffn_hidden_size = ffn_hidden_size + self.moe_num_experts = moe_num_experts + + self.w1 = torch.nn.Parameter( + torch.rand(moe_num_experts * ffn_hidden_size, + hidden_size, + device=device)) + self.v1 = torch.nn.Parameter( + torch.rand(moe_num_experts * ffn_hidden_size, + hidden_size, + device=device)) + self.w2 = torch.nn.Parameter( + torch.rand(moe_num_experts * ffn_hidden_size, + hidden_size, + device=device)) + self.activation_fn = activation_fn + + def forward(self, x: 
torch.Tensor, expert_idx: torch.Tensor): + expert_w1 = self.w1.view(self.moe_num_experts, self.ffn_hidden_size, + self.hidden_size)[expert_idx] + expert_v1 = self.v1.view(self.moe_num_experts, self.ffn_hidden_size, + self.hidden_size)[expert_idx] + expert_w2 = self.w2.view(self.moe_num_experts, self.ffn_hidden_size, + self.hidden_size)[expert_idx] + + x1 = x.matmul(expert_w1.t()) + x2 = x.matmul(expert_v1.t()) + x1 = self.activation_fn(x1) + x1 = x1 * x2 + x1 = x1.matmul(expert_w2) + return x1 + + +class DroplessMLP(torch.nn.Module): + + def __init__(self, hidden_size: int, ffn_hidden_size: int, mlp_type: str, + moe_num_experts: int, activation_fn: Callable, bias: bool, + device: torch.device): + super().__init__() + self.moe_num_experts = moe_num_experts + + if mlp_type == 'mlp': + self.mlp = MLP(hidden_size=hidden_size, + ffn_hidden_size=ffn_hidden_size, + moe_num_experts=moe_num_experts, + activation_fn=activation_fn, + device=device) + elif mlp_type == 'glu': + self.mlp = GLU(hidden_size=hidden_size, + ffn_hidden_size=ffn_hidden_size, + moe_num_experts=moe_num_experts, + activation_fn=activation_fn, + device=device) + else: + raise ValueError(f'Received unknown {mlp_type=}') + + def forward(self, x: torch.Tensor, scores: torch.Tensor, + expert_weights: torch.Tensor, top_experts: torch.Tensor): + in_shape = x.shape + hidden_size = in_shape[-1] + + x = x.view(-1, hidden_size) + out = torch.zeros_like(x) + + expert_mask = torch.nn.functional.one_hot( + top_experts, num_classes=self.moe_num_experts).permute(2, 1, 0) + for expert_idx in range(0, self.moe_num_experts): + topk_idx, token_idx = torch.where(expert_mask[expert_idx]) + if token_idx.shape[0] == 0: + continue + # In torch it is faster to index using lists than torch tensors + token_list = token_idx.tolist() + topk_list = topk_idx.tolist() + + expert_tokens = x[None, token_list].reshape(-1, hidden_size) + mlp_output = self.mlp(expert_tokens, expert_idx) + expert_out = mlp_output * expert_weights[token_list, 
topk_list, + None] + + out.index_add_(0, token_idx, expert_out) + + out = out.view(in_shape) + return out + + +class dMoE(torch.nn.Module): + + def __init__(self, hidden_size: int, ffn_hidden_size: int, + moe_num_experts: int, moe_top_k: int, mlp_type: str, + activation_fn: Callable, moe_jitter_eps: float, + moe_normalize_expert_weights: bool, + uniform_expert_assignment: bool, bias: bool, + device: torch.device): + super().__init__() + + # Token router. + self.router = LearnedRouter( + hidden_size, + moe_num_experts=moe_num_experts, + moe_top_k=moe_top_k, + moe_jitter_eps=moe_jitter_eps, + moe_normalize_expert_weights=moe_normalize_expert_weights, + uniform_expert_assignment=uniform_expert_assignment, + device=device, + ) + + # Expert computation helper. + self.experts = DroplessMLP( + hidden_size=hidden_size, + ffn_hidden_size=ffn_hidden_size, + mlp_type=mlp_type, + moe_num_experts=moe_num_experts, + activation_fn=activation_fn, + bias=bias, + device=device, + ) + + def forward(self, x: torch.Tensor): + # Compute the expert scores and assignments. + scores, expert_weights, top_experts = self.router(x) + # Compute the experts. 
+ return self.experts(x, scores, expert_weights, top_experts) diff --git a/llmfoundry/models/layers/ffn.py b/llmfoundry/models/layers/ffn.py index 9389cf385f..48d3d8c267 100644 --- a/llmfoundry/models/layers/ffn.py +++ b/llmfoundry/models/layers/ffn.py @@ -6,17 +6,26 @@ import logging from copy import deepcopy from functools import partial -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, List, Optional, Union import torch import torch.nn as nn +from torch.distributed._tensor import DeviceMesh, DTensor, Placement, Shard +from llmfoundry.models.layers.dmoe import dMoE from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY try: import transformer_engine.pytorch as te -except: - te = None + is_te_imported = True +except ModuleNotFoundError: + is_te_imported = False + +try: + import megablocks + is_megablocks_imported = True +except ModuleNotFoundError: + is_megablocks_imported = False log = logging.getLogger(__name__) @@ -79,6 +88,18 @@ def resolve_ffn_hidden_size( return ffn_hidden_size +def dtensorify_param(param: nn.Parameter, mesh: DeviceMesh, + placements: List[Placement]): + """Construct a DTensor from an already sharded local parameter.""" + param_dtensor = DTensor.from_local( + param.data, + device_mesh=mesh, + placements=placements, + run_check=False, + ) + return nn.Parameter(param_dtensor) + + class MPTMLP(nn.Module): def __init__( @@ -144,7 +165,6 @@ def __init__( **self.fc_kwargs, ) - @torch.compile def forward(self, x: torch.Tensor) -> torch.Tensor: return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x)) @@ -152,12 +172,20 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: FFN_CLASS_REGISTRY = { 'mptmlp': MPTMLP, 'mptglu': MPTGLU, + 'torch_dmoe': dMoE, } -if te is not None: +if is_te_imported: + import transformer_engine.pytorch as te te.LayerNormMLP._has_norm = True FFN_CLASS_REGISTRY['te_ln_mlp'] = te.LayerNormMLP +if is_megablocks_imported: + import megablocks + + FFN_CLASS_REGISTRY['mb_moe'] = 
megablocks.layers.moe.MoE + FFN_CLASS_REGISTRY['mb_dmoe'] = megablocks.layers.dmoe.dMoE + def build_ffn( d_model: int, @@ -185,7 +213,10 @@ def build_ffn( bias=bias, ) elif ffn_type == 'te_ln_mlp': - assert te is not None + if te is None: + raise RuntimeError( + 'Requirements for TransformerEngine not installed; see install instructions in `README.md`.' + ) ffn_hidden_size = resolve_ffn_hidden_size(d_model, expansion_ratio, ffn_hidden_size) if ffn_act_fn is not None: @@ -198,5 +229,99 @@ def build_ffn( bias=bias, **kwargs, ) + elif ffn_type in ('mb_moe', 'mb_dmoe'): + if megablocks is None: + raise RuntimeError( + 'Requirements for megablocks not installed; see install instructions in `README.md`.' + ) + args = kwargs['args'] + args.bias = bias + args.hidden_size = d_model + args.device = device + + ffn_hidden_size = resolve_ffn_hidden_size(d_model, expansion_ratio, + ffn_hidden_size) + args.ffn_hidden_size = ffn_hidden_size + + if ffn_act_fn is not None: + args.activation_fn = resolve_ffn_act_fn(ffn_act_fn) + + moe_world_size = 1 + expert_parallel_group = args.expert_parallel_group + if expert_parallel_group is not None: + moe_world_size = expert_parallel_group.size() + if kwargs.get('moe_world_size') != moe_world_size: + raise RuntimeError( + f'MoE expert_parallel_group configured with incorrect world size.' 
+ ) + + if ffn_type == 'mb_moe': + ffn = megablocks.layers.moe.MoE(args) + + # Fused initialization setup + # For param_init_fn, enables shape based init of stacked layers + ffn.experts.mlp._stack_dim = 0 + elif ffn_type == 'mb_dmoe': + ffn = megablocks.layers.dmoe.dMoE(args) + + # Fused initialization setup + # For param_init_fn, enables shape based init of fused layers + n_exp = min(1, args.moe_num_experts // moe_world_size) + ffn.experts.mlp._fused = (0, [ + (n + 1) * args.ffn_hidden_size for n in range(n_exp - 1) + ]) + else: + raise RuntimeError(f'Invalid ffn_type option: {ffn_type}.') + + # Attach args to MLP directly for use in param_init_fn + ffn.experts.mlp.hidden_size = args.ffn_hidden_size + ffn.experts.mlp.expert_parallel_group = expert_parallel_group + ffn.experts.mlp.weight_parallel_group = args.weight_parallel_group + + if moe_world_size > 1: + device_mesh = kwargs['device_mesh'] + + expert_mesh = device_mesh['expert_parallel'] + expert_placements: List[Placement] = [Shard(0)] + # Register in two loops as you cannot overwrite parameters while iterating over named_parameters() + dtensorified_params = [ + (name, + dtensorify_param(param=parameter, + mesh=expert_mesh, + placements=expert_placements)) + for name, parameter in ffn.experts.mlp.named_parameters() + ] + for name, dtensorified_param in dtensorified_params: + ffn.experts.mlp.register_parameter(name, dtensorified_param) + + device_mesh = kwargs['device_mesh'] + if device_mesh.mesh.ndim == 2: + submesh = device_mesh['weight_parallel'] + elif device_mesh.mesh.ndim == 3: + raise RuntimeError(f'HSDP + MoE is not supported.') + else: + raise ValueError( + f'{device_mesh.mesh.ndim=} not supported for MoE.') + + ffn.experts._fsdp_kwargs_dict = { + 'device_mesh': submesh, + } + return ffn + elif ffn_type == 'torch_dmoe': + return dMoE( + hidden_size=d_model, + ffn_hidden_size=resolve_ffn_hidden_size(d_model, expansion_ratio, + ffn_hidden_size), + moe_num_experts=kwargs.pop('moe_num_experts'), + 
moe_top_k=kwargs.pop('moe_top_k'), + mlp_type=kwargs.pop('mlp_type'), + bias=bias, + moe_jitter_eps=kwargs.pop('moe_jitter_eps'), + activation_fn=resolve_ffn_act_fn(ffn_act_fn), + moe_normalize_expert_weights=kwargs.pop( + 'moe_normalize_expert_weights'), + uniform_expert_assignment=kwargs.pop('uniform_expert_assignment'), + device=device, # pyright: ignore[reportGeneralTypeIssues] + ) raise ValueError(f'{ffn_type=} not recognized.') diff --git a/llmfoundry/models/mpt/configuration_mpt.py b/llmfoundry/models/mpt/configuration_mpt.py index 2f58ea312e..8383d33ec0 100644 --- a/llmfoundry/models/mpt/configuration_mpt.py +++ b/llmfoundry/models/mpt/configuration_mpt.py @@ -19,6 +19,7 @@ from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY # type: ignore (see note) from llmfoundry.models.layers.norm import LPLayerNorm # type: ignore (see note) from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY # type: ignore (see note) +from llmfoundry.models.layers.dmoe import dMoE # type: ignore (see note) from llmfoundry.models.layers.layer_builders import build_norm # type: ignore (see note) from llmfoundry.layers_registry import norms # type: ignore (see note) from llmfoundry.utils.registry_utils import construct_from_registry # type: ignore (see note) @@ -290,6 +291,8 @@ def _validate_config(self) -> None: ) elif self.ffn_config['ffn_type'] in ['mptmlp', 'mptglu']: self.ffn_config['fc_type'] = self.fc_type + elif self.ffn_config['ffn_type'] in ['mb_moe', 'mb_dmoe']: + self.ffn_config['return_bias'] = False elif self.ffn_config['ffn_type'] == 'te_ln_mlp': self.ffn_config['bias'] = not self.no_bias if 'ffn_act_fn' in self.ffn_config.keys(): diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index d54b797269..4a8f3943af 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -10,6 +10,7 @@ import math import warnings +from functools import cached_property from typing import (Any, Dict, 
List, Mapping, MutableMapping, Optional, Tuple, Union) @@ -49,6 +50,9 @@ from llmfoundry.models.layers.ffn import build_ffn as build_ffn from llmfoundry.models.layers.layer_builders import build_norm from llmfoundry.models.mpt.configuration_mpt import MPTConfig +from llmfoundry.models.utils.config_moe_args import config_moe_args +from llmfoundry.models.utils.mpt_param_count import (mpt_get_active_params, + mpt_get_total_params) # NOTE: All utils are imported directly even if unused so that # HuggingFace can detect all the needed files to copy into its modules folder. @@ -276,6 +280,8 @@ def _fsdp_wrap_fn( module: nn.Module, ) -> bool: # FSDP Wrap function for MPT Models + if hasattr(module, '_fsdp_kwargs_dict'): + return module._fsdp_kwargs_dict return isinstance(module, MPTBlock) @@ -316,10 +322,20 @@ def __init__(self, config: MPTConfig): config.d_model, device=config.init_device) self.emb_drop = nn.Dropout(config.emb_pdrop) + self.mb_args = None + block_args = config.to_dict() + if block_args['ffn_config']['ffn_type'] in ('mb_moe', 'mb_dmoe'): + block_args['ffn_config'] = config_moe_args( + block_args['ffn_config'], + config.d_model, + config.expansion_ratio, + config.n_layers, + ) + self.mb_args = block_args['ffn_config'].get('args') self.blocks = nn.ModuleList([ MPTBlock( device=config.init_device, - **config.to_dict(), + **block_args, ) for _ in range(config.n_layers) ]) @@ -980,8 +996,6 @@ def __init__( allow_embedding_resizing=True, ) - self.n_active_params = sum(p.numel() for p in self.parameters()) - loss_fn_config = om_model_config.get('loss_fn', 'fused_crossentropy') if loss_fn_config == 'fused_crossentropy': try: @@ -1012,6 +1026,15 @@ def get_targets(self, batch: Mapping) -> torch.Tensor: return targets def forward(self, batch: MutableMapping) -> CausalLMOutputWithPast: + if self.config.ffn_config['ffn_type'] in ('mb_moe', 'mb_dmoe'): + # Clear MegaBlocks MoE load balancing loss cache + try: # Add try/catch to avoid transformers complaining and 
raising errors + from megablocks.layers.moe import clear_load_balancing_loss + except: + raise RuntimeError( + 'Requirements for MegaBlocks not installed; see install instructions in `README.md`.' + ) + clear_load_balancing_loss() return self.model( input_ids=batch.get('input_ids', None), attention_mask=batch.get('attention_mask', None), @@ -1020,7 +1043,7 @@ def forward(self, batch: MutableMapping) -> CausalLMOutputWithPast: ) def loss(self, outputs: CausalLMOutputWithPast, - batch: Mapping) -> torch.Tensor: + batch: Mapping) -> Union[dict, torch.Tensor]: targets = self.get_targets(batch) losses = self.loss_fn(outputs.logits.view(-1, outputs.logits.size(-1)), targets.view(-1)) @@ -1030,18 +1053,40 @@ def loss(self, outputs: CausalLMOutputWithPast, else: loss = losses.sum() / (targets != self.loss_fn.ignore_index).sum() + if self.config.ffn_config['ffn_type'] in ('mb_moe', 'mb_dmoe'): + # MegaBlocks MoE load balancing loss + try: # Add try/catch to avoid transformers complaining and raising errors + from megablocks.layers.moe import batched_load_balancing_loss + except: + raise RuntimeError( + 'Requirements for MegaBlocks not installed; see install instructions in `README.md`.' + ) + lbl = batched_load_balancing_loss(self.model.transformer.mb_args) + return { + 'total': loss + lbl, + 'loss': loss, + 'lbl': lbl, + } + return loss - def flops_per_batch(self, batch: Mapping) -> int: + @cached_property + def n_total_params(self): + """Gets the total number of parameters in the model.""" + return mpt_get_total_params(self) + + @cached_property + def n_active_params(self): + """Gets the total number of active parameters in the model.""" + return mpt_get_active_params(self) + + def flops_per_batch(self, batch: Mapping): # Note: this computation does not take into account padding, and assumes # that the dataset has been constructed without padding. 
Additionally, we # assume the backward pass is approximately 2x the forward pass bs, msl = batch['input_ids'].shape[0:2] params = self.n_active_params - if not self.model.transformer.config.tie_word_embeddings: - # embedding layers are lookup tables, therefore are not counted in the FLOP computation - params -= self.model.transformer.wte.weight.numel() params_flops_per_token = 2 * params params_flops_per_seq = params_flops_per_token * msl attn_flops_per_seq = (self.model.config.n_layers * 2 * 2 * diff --git a/llmfoundry/models/utils/__init__.py b/llmfoundry/models/utils/__init__.py index 7c808ff449..41313b8729 100644 --- a/llmfoundry/models/utils/__init__.py +++ b/llmfoundry/models/utils/__init__.py @@ -1,8 +1,11 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 +from llmfoundry.models.utils.config_moe_args import config_moe_args from llmfoundry.models.utils.meta_init_context import (init_empty_weights, init_on_device) +from llmfoundry.models.utils.mpt_param_count import (mpt_get_active_params, + mpt_get_total_params) from llmfoundry.models.utils.param_init_fns import (MODEL_INIT_REGISTRY, generic_param_init_fn_) @@ -11,4 +14,7 @@ 'init_on_device', 'generic_param_init_fn_', 'MODEL_INIT_REGISTRY', + 'config_moe_args', + 'mpt_get_active_params', + 'mpt_get_total_params', ] diff --git a/llmfoundry/models/utils/act_ckpt.py b/llmfoundry/models/utils/act_ckpt.py index bde7c92bd7..1975865f1b 100644 --- a/llmfoundry/models/utils/act_ckpt.py +++ b/llmfoundry/models/utils/act_ckpt.py @@ -7,7 +7,7 @@ from llmfoundry.layers_registry import norms from llmfoundry.models.layers.attention import ATTN_CLASS_REGISTRY -from llmfoundry.models.layers.blocks import MPTBlock +from llmfoundry.models.layers.blocks import FusedNormAttentionNorm, MPTBlock from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY @@ -25,6 +25,8 @@ def get_act_ckpt_module(mod_name: str) -> Any: """Get the module type from the module name.""" if mod_name.lower() == 
'mptblock': mod_type = MPTBlock + elif mod_name.lower() == 'norm_attn_norm': + mod_type = FusedNormAttentionNorm elif mod_name in ATTN_CLASS_REGISTRY: mod_type = ATTN_CLASS_REGISTRY[mod_name] elif mod_name in FFN_CLASS_REGISTRY: diff --git a/llmfoundry/models/utils/config_moe_args.py b/llmfoundry/models/utils/config_moe_args.py new file mode 100644 index 0000000000..b69cd18348 --- /dev/null +++ b/llmfoundry/models/utils/config_moe_args.py @@ -0,0 +1,167 @@ +# Copyright 2022 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +"""Helper function to configure MPT with MoEs.""" + +from typing import Union + +import torch +from packaging import version +from torch import distributed + +from llmfoundry.models.layers.ffn import resolve_ffn_hidden_size + + +def create_set_process_group(k: int): + """Creates a new distributed group using sets of k GPUs. + + For example, if you have 16 GPUs and input k=4, the resulting process groups + will have ranks: + process group 0 ranks: [ 0, 1, 2, 3] + process group 1 ranks: [ 4, 5, 6, 7] + process group 2 ranks: [ 8, 9, 10, 11] + process group 3 ranks: [12, 13, 14, 15] + + Args: + k (int): Number of GPUs to use in set size. + + Returns: + A handle of distributed group that can be given to collective calls. + """ + world_size = distributed.get_world_size() + if world_size % k != 0: + raise RuntimeError(f'{world_size=} must be divisible by {k=}.') + start = distributed.get_rank() // k * k + ranks = tuple(range(start, start + k)) + return distributed.new_group(ranks) + + +def config_megablocks_moe_args( + ffn_config: dict, + d_model: int, + expansion_ratio: Union[int, float], + n_layers: int, +) -> dict: + """Configures `ffn_config` for MegaBlocks MoE. + + We prepare all necessary arguments for `megablocks.layers.arguments.Arguments` so that process + groups can be initialized and shared across all blocks in the network. + + Args: + ffn_config (dict): FFN configuation before the MegaBlocks MoE is configured. 
+ d_model (int): Hidden size of the network. + expansion_ratio (Union[int, float]): Expansion ratio in FFN. + n_layers (int): Number of blocks used in the network. + + Returns: + ffn_config (dict): FFN configuration with MegaBlocks MoE configured. + """ + try: + import megablocks + except: + raise RuntimeError( + 'Requirements for MegaBlocks not installed; see install instructions in `README.md`.' + ) + + ffn_config.setdefault('fp16', False) + ffn_config.setdefault('bf16', False) + ffn_config['num_layers'] = n_layers + + ffn_type = ffn_config.pop('ffn_type') + fc_type = ffn_config.pop('fc_type') + ffn_act_fn = ffn_config.pop('ffn_act_fn', None) + + # Config for MegaBlocks MoE world size and device mesh + world_size = 1 # default + moe_world_size = ffn_config.pop('moe_world_size') + device_mesh = None + device_mesh_cfg = ffn_config.pop('device_mesh', None) + if moe_world_size > 1: + if version.parse(torch.__version__.split('.dev')[0]) < version.parse( + '2.2.0'): # type: ignore + raise RuntimeError( + f'MoE world size > 1 is not supported in torch version {torch.__version__}<2.2.'
+ ) + + from torch.distributed._tensor.device_mesh import init_device_mesh + + world_size = distributed.get_world_size() + if world_size < moe_world_size or world_size % moe_world_size: + raise ValueError( + f'Invalid world size configuration: {world_size=} and {moe_world_size=}' + ) + + # FSDP + if device_mesh_cfg is None or len(device_mesh_cfg) == 1: + if device_mesh_cfg is not None: + world_size = device_mesh_cfg[0] + sharding_group_dim = world_size // moe_world_size + device_mesh = init_device_mesh( + 'cuda', + (sharding_group_dim, moe_world_size), + mesh_dim_names=('weight_parallel', 'expert_parallel'), + ) + else: + raise ValueError(f'{device_mesh_cfg=} must be length 1') + + ffn_config['moe_expert_model_parallelism'] = True + ffn_config['expert_parallel_group'] = device_mesh[ + 'expert_parallel'].get_group(0) # type: ignore + + lbl_process_group = ffn_config.get('lbl_process_group', None) + if lbl_process_group is not None: + if lbl_process_group == 'expert_group': + lbl_process_group = ffn_config['expert_parallel_group'] + elif lbl_process_group == 'global_group': + lbl_process_group = distributed.group.WORLD + elif isinstance(lbl_process_group, int): + lbl_process_group = create_set_process_group(lbl_process_group) + elif lbl_process_group is not None: + raise ValueError( + f'Unknown {lbl_process_group=}. Options are: none | expert_group | global_group | an integer group size.'
+ ) + ffn_config['lbl_process_group'] = lbl_process_group + + ffn_hidden_size = resolve_ffn_hidden_size(d_model, expansion_ratio) + ffn_config.setdefault('ffn_hidden_size', ffn_hidden_size) + + args = megablocks.layers.arguments.Arguments( + hidden_size=d_model, + **ffn_config, + ) + ffn_config['args'] = args + ffn_config['device_mesh'] = device_mesh + ffn_config['moe_world_size'] = moe_world_size + ffn_config['ffn_type'] = ffn_type + ffn_config['fc_type'] = fc_type + ffn_config['ffn_act_fn'] = ffn_act_fn + + return ffn_config + + +def config_moe_args( + ffn_config: dict, + d_model: int, + expansion_ratio: Union[int, float], + n_layers: int, +) -> dict: + """Configures `ffn_config` for MoE. + + Args: + ffn_config (dict): FFN configuration before the MoE is configured. + d_model (int): Hidden size of the network. + expansion_ratio (int, float): Expansion ratio in FFN. + n_layers (int): Number of blocks used in the network. + + Returns: + ffn_config (dict): FFN configuration with MoE configured.
+ """ + if ffn_config['ffn_type'] in ('mb_moe', 'mb_dmoe'): + return config_megablocks_moe_args( + ffn_config=ffn_config, + d_model=d_model, + expansion_ratio=expansion_ratio, + n_layers=n_layers, + ) + else: + raise ValueError(f'Invalid ffn_type ({ffn_config["ffn_type"]}).') diff --git a/llmfoundry/models/utils/meta_init_context.py b/llmfoundry/models/utils/meta_init_context.py index c22c226c28..d72a289a73 100644 --- a/llmfoundry/models/utils/meta_init_context.py +++ b/llmfoundry/models/utils/meta_init_context.py @@ -21,6 +21,7 @@ import torch import torch.nn as nn +from torch.distributed._tensor import DTensor @contextmanager @@ -86,11 +87,13 @@ def register_empty_parameter(self: torch.nn.Module, name: str, if param is not None: parameter = self._parameters[name] assert parameter is not None - - param_cls = type(parameter) - kwargs = parameter.__dict__ - - self._parameters[name] = param_cls(parameter.to(device), **kwargs) + if isinstance(parameter, DTensor): + self._parameters[name] = parameter.to(device) # type: ignore + else: + param_cls = type(parameter) + kwargs = parameter.__dict__ + self._parameters[name] = param_cls(parameter.to(device), + **kwargs) def register_empty_buffer(self: torch.nn.Module, name: str, diff --git a/llmfoundry/models/utils/mpt_param_count.py b/llmfoundry/models/utils/mpt_param_count.py new file mode 100644 index 0000000000..d90929713b --- /dev/null +++ b/llmfoundry/models/utils/mpt_param_count.py @@ -0,0 +1,167 @@ +# Copyright 2022 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +"""Helper functions for computing parameter counts for MPT model. + +Use if generic `sum(p.numel() for p in self.parameters())` +style computation does not account for MoE parameter sharding. +The helper functions in this file account for MoE parameter +sharding in the parameter count calculation. The functions below +calculate the total parameter count and the active parameter count. 
+Note: MPT has both n_total_params and n_active_params methods. +""" + +from typing import Union + +from torch import Tensor, nn +from torch.distributed._tensor import DTensor + + +def module_n_params(module: nn.Module) -> int: + """Gets the number of parameters in this module excluding child modules. + + Args: + module (nn.Module): Module of which we get the number of parameters. + + Returns: + An int for the number of parameters in this module. + """ + n_params = 0 + for p in module.parameters(recurse=False): + n_params += p.numel() + return n_params + + +def _dtensor_safe_check_numel(tensor: Union[Tensor, DTensor]) -> int: + if isinstance(tensor, DTensor): + tensor = tensor._local_tensor + return tensor.numel() + + +def megablocks_n_total_params(mpt_model) -> int: # type: ignore + """Calculates the number of parameters in a MegaBlocks enabled MPT model. + + MoE experts are sharded across workers. This function scans for MegaBlocks + modules then multiplies expert params count by MoE world size. + + Args: + mpt_model (ComposerMPTCausalLM): MPT model of which the number of + parameters is calculated. + + Returns: + An int for the total number of parameters in this MPT model. + """ + import megablocks + + moe_world_size = mpt_model.config.ffn_config.get('moe_world_size') + + if mpt_model.config.ffn_config.get('moe_weight_parallelism', False): + # If MegaBlocks shards experts, the total sharding world size + # must be increased by the degree to which MegaBlocks shards the + # experts. 
+ mb_args = mpt_model.model.transformer.mb_args + moe_world_size *= mb_args.weight_parallel_group.size() + + n_total_params = 0 + for module in mpt_model.modules(): + if isinstance( + module, + (megablocks.layers.mlp.SparseMLP, megablocks.layers.mlp.MLP)): + n_w1 = _dtensor_safe_check_numel(module.w1) + n_total_params += n_w1 * moe_world_size + n_w2 = _dtensor_safe_check_numel(module.w2) + n_total_params += n_w2 * moe_world_size + + # GLU has an extra weight + if hasattr(module, 'v1'): + n_v1 = _dtensor_safe_check_numel(module.v1) + n_total_params += n_v1 * moe_world_size + else: + n_total_params += module_n_params(module) + + return n_total_params + + +def megablocks_n_active_params(mpt_model) -> int: # type: ignore + """Calculates the number of active parameters in a MegaBlocks enabled MPT. + + This requires we calculate the number of elements per expert and + multiply this by top k. + + Args: + mpt_model (ComposerMPTCausalLM): MPT model of which the number of + active parameters is calculated. + + Returns: + An int for the active number of parameters in this MPT model. 
+ """ + import megablocks + + moe_num_experts = mpt_model.config.ffn_config.get('moe_num_experts', 1) + moe_world_size = mpt_model.config.ffn_config.get('moe_world_size') + + local_experts = moe_num_experts / moe_world_size # if local_experts is < 1, then the expert is sharded + if mpt_model.config.ffn_config.get('moe_weight_parallelism', False): + mb_args = mpt_model.model.transformer.mb_args + local_experts /= mb_args.weight_parallel_group.size() + + moe_top_k = mpt_model.config.ffn_config.get('moe_top_k', 1) + n_active_params = 0 + for module in mpt_model.modules(): + if isinstance( + module, + (megablocks.layers.mlp.SparseMLP, megablocks.layers.mlp.MLP)): + n_w1 = _dtensor_safe_check_numel(module.w1) + n_active_params += int(n_w1 / local_experts * moe_top_k) + n_w2 = _dtensor_safe_check_numel(module.w2) + n_active_params += int(n_w2 / local_experts * moe_top_k) + + # GLU has an extra weight + if hasattr(module, 'v1'): + n_v1 = _dtensor_safe_check_numel(module.v1) + n_active_params += int(n_v1 / local_experts * moe_top_k) + else: + n_active_params += module_n_params(module) + + return n_active_params + + +def mpt_get_total_params(mpt_model) -> int: # type: ignore + """Calculates the total parameter count of an MPT model. + + Note: Must be called before model parameters are sharded by FSDP. + + Args: + mpt_model (ComposerMPTCausalLM): MPT model of which the number of + total parameters is calculated. + + Returns: + An int for the total number of parameters in this MPT model. + """ + if mpt_model.config.ffn_config['ffn_type'] in ('mb_moe', 'mb_dmoe'): + return megablocks_n_total_params(mpt_model) + else: + return sum(p.numel() for p in mpt_model.parameters()) + + +def mpt_get_active_params(mpt_model) -> int: # type: ignore + """Calculates the active parameter count of an MPT model. + + Note: Must be called before model parameters are sharded by FSDP. + + Args: + mpt_model (ComposerMPTCausalLM): MPT model of which the number of + active parameters is calculated.
+ + Returns: + An int for the active number of parameters in this MPT model. + """ + if mpt_model.config.ffn_config['ffn_type'] in ('mb_moe', 'mb_dmoe'): + params = megablocks_n_active_params(mpt_model) + else: + params = sum(p.numel() for p in mpt_model.parameters()) + if not mpt_model.model.transformer.config.tie_word_embeddings: + # Embedding layers are lookup tables, therefore are not counted in the FLOP computation + params -= _dtensor_safe_check_numel( + mpt_model.model.transformer.wte.weight) + return params diff --git a/llmfoundry/models/utils/param_init_fns.py b/llmfoundry/models/utils/param_init_fns.py index 35dc88a408..16376de451 100644 --- a/llmfoundry/models/utils/param_init_fns.py +++ b/llmfoundry/models/utils/param_init_fns.py @@ -4,13 +4,16 @@ import math import warnings from collections.abc import Sequence +from copy import deepcopy from functools import partial from typing import Any, Callable, Optional, Tuple, Union import torch from torch import nn +from torch.distributed._tensor import DTensor from llmfoundry.layers_registry import norms +from llmfoundry.models.layers.dmoe import GLU, MLP from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY try: @@ -18,6 +21,11 @@ except: te = None +try: + import megablocks +except: + megablocks = None + def torch_default_param_init_fn_( module: nn.Module, @@ -30,27 +38,114 @@ def torch_default_param_init_fn_( module.reset_parameters() -def fused_init_helper_(module: nn.Module, init_fn_: Callable) -> None: - # parameter initialization is often based on the parameters shape. - # If a layer is fused, initialization should be based on the shapes - # of the original tensor instead of the shape of the fused tensor. - # Layers which are fused should have the _fused attribute defined. - # The first element of _fused is the dimension along which the tensor is fused. - # This is followed by an iterable of split indices." 
- +def fused_init_helper_( + module: nn.Module, + init_fn_: Callable, + name_param: str = 'weight', +): + """Initializes parameters which have been fused for efficiency purposes. + + Parameter initialization is often based on the parameter's shape. If a layer is fused, + initialization should be based on the shapes of the original tensor instead of the + shape of the fused tensor. Layers which are fused should have the _fused + attribute. First element of _fused is the dimension along which the tensor is fused. + Second element is an iterable of split indices. + + Args: + module (nn.Module): The module to initialize. + init_fn_ (Callable): Initialization method. + name_param (str): Name of parameter to initialize within the module. + """ _fused = getattr(module, '_fused', None) - if _fused is None: raise RuntimeError(f'Internal logic error') - assert isinstance(module.weight, torch.Tensor) + fused_param_init_helper(getattr(module, name_param), init_fn_, _fused) + - dim, splits = _fused - splits = (0, *splits, module.weight.size(dim)) +def fused_param_init_helper( + param: torch.Tensor, + init_fn_: Callable, + fused_parameters: tuple[int, list[int]], +): + """Initializes parameters that are fused together. + + Args: + param (torch.Tensor): Tensor to initialize. + init_fn_ (Callable): Initialization method. + fused_parameters (tuple[int, list[int]]): First element of _fused is the dimension + along which the tensor is fused. Second element is an iterable of split indices.
+ """ + p_ndims = param.ndim + dim, splits = fused_parameters + splits = (0, *splits, param.size(dim)) # type: ignore for s, e in zip(splits[:-1], splits[1:]): - slice_indices = [slice(None)] * module.weight.ndim + slice_indices = [slice(None)] * p_ndims # type: ignore slice_indices[dim] = slice(s, e) - init_fn_(module.weight[slice_indices]) + init_fn_(param[tuple(slice_indices)]) # type: ignore + + +def stacked_init_helper_( + module: nn.Module, + init_fn_: Callable, + name_param: str = 'weight', +): + """Initializes parameters stacked along a new dimension. + + Parameter initialization is often based on the parameter's shape. If a layer is stacked, + initialization should be based on the shapes of the original tensor instead of the + shape of the stacked tensor. Layers which are stacked should have the _stack_dim + attribute defining the new dimension along which they are stacked. + + Args: + module (nn.Module): The module to initialize. + init_fn_ (Callable): Initialization method. + name_param (str): Name of parameter to initialize within the module. + """ + stack_dim = getattr(module, '_stack_dim', None) + if stack_dim is None: + raise RuntimeError(f'Internal logic error') + + stacked_param_init_helper(getattr(module, name_param), init_fn_, stack_dim) + + +def stacked_param_init_helper( + param: torch.Tensor, + init_fn_: Callable, + stack_dim: int, +): + """Initialize parameters stacked along a new dimension. + + Args: + param (torch.Tensor): Tensor to initialize. + init_fn_ (Callable): Initialization method. + stack_dim (int): Dimension along which parameters are stacked + """ + p_ndims = param.ndim + + for idx in range(param.size(stack_dim)): + slice_indices = [slice(None)] * p_ndims # type: ignore + slice_indices[stack_dim] = idx # type: ignore + init_fn_(param[tuple(slice_indices)]) # type: ignore + + +def _flip_fan_mode(init_fn_: Callable): + """Changes the mode of an init_fn_. + + init_fn_'s "mode" is set to operate on standard torch modules e.g. torch.nn.Linear.
+ If a custom layer transposes its weights before they are applied such that it is + opposite pytorch's conventions, we must flip the fan mode, swapping fan_in and fan_out. + + Args: + init_fn_ (Callable): Initialization method. + """ + _init_fn_ = deepcopy(init_fn_) + if 'mode' in _init_fn_.keywords: + if _init_fn_.keywords['mode'] == 'fan_in': + _init_fn_.keywords['mode'] = 'fan_out' + elif _init_fn_.keywords['mode'] == 'fan_out': + _init_fn_.keywords['mode'] = 'fan_in' + return _init_fn_ def generic_param_init_fn_( @@ -191,6 +286,35 @@ def generic_param_init_fn_( with torch.no_grad(): module.fc2_weight.div_(div_is_residual) # type: ignore + elif megablocks is not None and isinstance(module, ( + megablocks.layers.moe.MoE, + megablocks.layers.dmoe.dMoE, + megablocks.layers.moe.ParallelMLP, + megablocks.layers.dmoe.ParallelDroplessMLP, + )): + if hasattr(module, 'bias') and module.bias is not None: + # Initialize bias to 0 + torch.nn.init.zeros_(module.bias) # type: ignore + elif megablocks is not None and isinstance(module, + megablocks.layers.glu.SparseGLU): + _megablocks_sparse_glu_generic_param_init_fn_( + module, init_fn_, bool(init_div_is_residual), div_is_residual) + elif megablocks is not None and isinstance(module, + megablocks.layers.mlp.SparseMLP): + _megablocks_sparse_mlp_generic_param_init_fn_( + module, init_fn_, bool(init_div_is_residual), div_is_residual) + elif megablocks is not None and isinstance(module, + megablocks.layers.mlp.MLP): + _megablocks_mlp_generic_param_init_fn_(module, init_fn_, + bool(init_div_is_residual), + div_is_residual) + elif isinstance(module, GLU): + init_fn_(module.w1) + init_fn_(module.v1) + init_fn_(module.w2) + elif isinstance(module, MLP): + init_fn_(module.w1) + init_fn_(module.w2) else: for _ in module.parameters(recurse=False): # raise error if uninitialized module has any parameters @@ -199,7 +323,197 @@ def generic_param_init_fn_( ) -def _normal_init_(std: float, mean: float = 0.0) -> Callable: +def
_megablocks_sparse_mlp_generic_param_init_fn_( + module: nn.Module, + init_fn_: Callable, + init_div_is_residual: bool = False, + div_is_residual: float = 1.0, +): + """Initializes MegaBlocks MLP. + + To enable elastic deterministic initialization, this method creates the entire + weight matrix then slice into the weight tensors such that the sampled weights + should not vary between moe world size for the same random seed. + + Args: + module (nn.Module): The module to initialize. + init_fn_ (Callable): Initialization method. + init_div_is_residual (bool): Flag enabling parameters tagged with _is_residual + flag to be divided by div_is_residual. + div_is_residual (float): The value by which parameter initialization is divided + if init_div_is_residual flag is enabled. + """ + expert_process_group_size, rank, weight_parallel_group_size, weight_parallel_group_rank = 1, 0, 1, 0 + if module.expert_parallel_group is not None: + expert_process_group_size = int( + module.expert_parallel_group.size()) # type: ignore + rank = int(module.expert_parallel_group.rank()) # type: ignore + if module.weight_parallel_group is not None: + weight_parallel_group_size = int( + module.weight_parallel_group.size()) # type: ignore + weight_parallel_group_rank = int( + module.weight_parallel_group.rank()) # type: ignore + + hidden_size = int(module.hidden_size) # type: ignore + + # Initialize w1 + w1 = module.w1 + if isinstance(w1, DTensor): + w1 = w1._local_tensor + w1_size = list(w1.shape) # type: ignore + w1_size[ + 0] = w1_size[0] * expert_process_group_size * weight_parallel_group_size + + n_exp = w1_size[0] // hidden_size + _fused = (0, [(n + 1) * hidden_size for n in range(n_exp - 1)]) + + _w1 = w1.new_empty(w1_size) # type: ignore + fused_param_init_helper(_w1, init_fn_, _fused) + _w1_local = _w1.chunk(expert_process_group_size, dim=0)[rank] + _w1_local_slice = _w1_local.chunk(weight_parallel_group_size, + dim=0)[weight_parallel_group_rank] + with torch.no_grad(): + 
w1.copy_(_w1_local_slice) # type: ignore + + # Initialize w2 + w2 = module.w2 + if isinstance(w2, DTensor): + w2 = w2._local_tensor + w2_size = list(w2.shape) # type: ignore + w2_size[ + 0] = w2_size[0] * expert_process_group_size * weight_parallel_group_size + _w2 = w2.new_empty(w2_size) # type: ignore + # MegaBlocks operates on w2 as x @ w2, so needs flipped fan mode + fused_param_init_helper(_w2, _flip_fan_mode(init_fn_), _fused) + _w2_local = _w2.chunk(expert_process_group_size, dim=0)[rank] + _w2_local_slice = _w2_local.chunk(weight_parallel_group_size, + dim=0)[weight_parallel_group_rank] + with torch.no_grad(): + w2.copy_(_w2_local_slice) # type: ignore + if init_div_is_residual is not False: + with torch.no_grad(): + w2.div_(div_is_residual) # type: ignore + + +def _megablocks_sparse_glu_generic_param_init_fn_( + module: nn.Module, + init_fn_: Callable, + init_div_is_residual: bool = False, + div_is_residual: float = 1.0, +): + """Initializes MegaBlocks Sparse GLU. + + Extends the Megablocks Sparse MLP case to an additional weight v1 for GLUs. + This additional weight v1 has the same initialization procedure as w1 for MLPs. + + Args: + module (nn.Module): The module to initialize. + init_fn_ (Callable): Initialization method. + init_div_is_residual (bool): Flag enabling parameters tagged with _is_residual + flag to be divided by div_is_residual. + div_is_residual (float): The value by which parameter initialization is divided + if init_div_is_residual flag is enabled. 
+ """ + # Init for w1 and w2 matrices + _megablocks_sparse_mlp_generic_param_init_fn_( + module=module, + init_fn_=init_fn_, + init_div_is_residual=init_div_is_residual, + div_is_residual=div_is_residual) + + # Init ported from _megablocks_sparse_mlp_generic_param_init_fn_ for v1 + expert_process_group_size, rank, weight_parallel_group_size, weight_parallel_group_rank = 1, 0, 1, 0 + if module.expert_parallel_group is not None: + expert_process_group_size = int( + module.expert_parallel_group.size()) # type: ignore + rank = int(module.expert_parallel_group.rank()) # type: ignore + if module.weight_parallel_group is not None: + weight_parallel_group_size = int( + module.weight_parallel_group.size()) # type: ignore + weight_parallel_group_rank = int( + module.weight_parallel_group.rank()) # type: ignore + + hidden_size = int(module.hidden_size) # type: ignore + + # Separately initialize v1 + v1 = module.v1 + if isinstance(v1, DTensor): + v1 = v1._local_tensor + v1_size = list(v1.shape) # type: ignore + v1_size[ + 0] = v1_size[0] * expert_process_group_size * weight_parallel_group_size + + n_exp = v1_size[0] // hidden_size + _fused = (0, [(n + 1) * hidden_size for n in range(n_exp - 1)]) + + _v1 = v1.new_empty(v1_size) # type: ignore + fused_param_init_helper(_v1, init_fn_, _fused) + _v1_local = _v1.chunk(expert_process_group_size, dim=0)[rank] + _v1_local_slice = _v1_local.chunk(weight_parallel_group_size, + dim=0)[weight_parallel_group_rank] + with torch.no_grad(): + v1.copy_(_v1_local_slice) # type: ignore + + +def _megablocks_mlp_generic_param_init_fn_( + module: nn.Module, + init_fn_: Callable, + init_div_is_residual: bool = False, + div_is_residual: float = 1.0, +): + """Initializes MegaBlocks' MLP. + + To enable elastic deterministic initialization, this method creates the entire + weight matrix then slice into the weight tensors such that the sampled weights + should not vary between moe world size for the same random seed. 
+ + Args: + module (nn.Module): The module to initialize. + init_fn_ (Callable): Initialization method. + init_div_is_residual (bool): Flag enabling parameters tagged with _is_residual + flag to be divided by div_is_residual. + div_is_residual (float): The value by which parameter initialization is divided + if init_div_is_residual flag is enabled. + """ + expert_process_group_size, rank, weight_parallel_group_size, w_rank = 1, 0, 1, 0 + if module.expert_parallel_group is not None: + expert_process_group_size = int( + module.expert_parallel_group.size()) # type: ignore + rank = int(module.expert_parallel_group.rank()) # type: ignore + if module.weight_parallel_group is not None: + weight_parallel_group_size = int( + module.weight_parallel_group.size()) # type: ignore + w_rank = int(module.weight_parallel_group.rank()) # type: ignore + + _init_fn_ = _flip_fan_mode(init_fn_) + + # Initialize w1 + w1_size = list(module.w1.shape) # type: ignore + w1_size[0] = w1_size[0] * expert_process_group_size + w1_size[1] = w1_size[1] * weight_parallel_group_size + _w1 = module.w1.new_empty(w1_size) # type: ignore + stacked_param_init_helper(_w1, _init_fn_, module._stack_dim) # type: ignore + _w1_local = _w1.chunk(expert_process_group_size, dim=0)[rank] + _w1_local_slice = _w1_local.chunk(weight_parallel_group_size, dim=1)[w_rank] + with torch.no_grad(): + module.w1.copy_(_w1_local_slice) # type: ignore + + # Initialize w2 + w2_size = list(module.w2.shape) # type: ignore + w2_size[0] = w2_size[0] * expert_process_group_size + w2_size[1] = w2_size[1] * weight_parallel_group_size + _w2 = module.w2.new_empty(w2_size) # type: ignore + stacked_param_init_helper(_w2, _init_fn_, module._stack_dim) # type: ignore + _w2_local = _w2.chunk(expert_process_group_size, dim=0)[rank] + _w2_local_slice = _w2_local.chunk(weight_parallel_group_size, dim=1)[w_rank] + with torch.no_grad(): + module.w2.copy_(_w2_local_slice) # type: ignore + if init_div_is_residual is not False: + with torch.no_grad(): 
+ module.w2.div_(div_is_residual) # type: ignore + + +def _normal_init_(std: float, mean: float = 0.0): return partial(torch.nn.init.normal_, mean=mean, std=std) @@ -263,8 +577,8 @@ def small_param_init_fn_( **kwargs: Any, ) -> None: del kwargs # unused, just to capture any extra args from the config - # very close to kaiming normal - # from Transformers without Tears (2019) - Nguyen & Salazar + # Very close to kaiming normal + # From Transformers without Tears (2019) - Nguyen & Salazar std = math.sqrt(2 / (5 * d_model)) _normal_param_init_fn_( module=module, diff --git a/llmfoundry/tokenizers/tiktoken.py b/llmfoundry/tokenizers/tiktoken.py index 1ef2a8cacf..298e1bc984 100644 --- a/llmfoundry/tokenizers/tiktoken.py +++ b/llmfoundry/tokenizers/tiktoken.py @@ -1,5 +1,6 @@ -# Copyright 2022 MosaicML LLM Foundry authors +# Copyright 2024 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 + from functools import lru_cache from typing import Any, Dict, List, Optional, Tuple diff --git a/llmfoundry/utils/config_utils.py b/llmfoundry/utils/config_utils.py index 0edbae80a5..d2c3b733c0 100644 --- a/llmfoundry/utils/config_utils.py +++ b/llmfoundry/utils/config_utils.py @@ -129,6 +129,17 @@ def process_init_device(model_cfg: DictConfig, fsdp_config: Optional[Dict]): fsdp_config.setdefault('use_orig_params', False) fsdp_config.setdefault('load_monolith_rank0_only', True) + # Set ffn_config.device_mesh to fsdp_config.device_mesh + if fsdp_config is not None and 'device_mesh' in fsdp_config and 'ffn_config' in model_cfg and model_cfg[ + 'ffn_config'].get('ffn_type', None) in {'mb_moe', 'mb_dmoe'}: + # Raise ValueError if not using device mesh with MoE expert parallelism + if fsdp_config['device_mesh'] is None and model_cfg['ffn_config'].get( + 'moe_world_size', 1) > 1: + raise ValueError( + 'device_mesh must be specified in fsdp_config when using MoE with moe_world_size > 1.' 
+ ) + model_cfg.ffn_config.device_mesh = fsdp_config['device_mesh'] + # No mixed precision needed for weights when they're already 16 bits master_dtype = model_cfg.get('master_weights_dtype') small_dtypes = ('bf16', 'fp16', 'float16', 'bfloat16', 'amp_fp16', diff --git a/scripts/data_prep/convert_text_to_mds.py b/scripts/data_prep/convert_text_to_mds.py index df39e38a90..be986fc24d 100644 --- a/scripts/data_prep/convert_text_to_mds.py +++ b/scripts/data_prep/convert_text_to_mds.py @@ -114,6 +114,13 @@ def parse_args() -> Namespace: help='If true, reprocess the input_folder to mds format. Otherwise, ' + 'only reprocess upon changes to the input folder or dataset creation parameters.', ) + parser.add_argument( + '--trust-remote-code', + type=bool, + required=False, + default=False, + help='If true, allows custom code to be executed to load the tokenizer', + ) parsed = parser.parse_args() @@ -124,7 +131,8 @@ def parse_args() -> Namespace: parser.error( 'Cannot set --eos_text with --use_tokenizer_eos. Please specify one.' ) - tokenizer = AutoTokenizer.from_pretrained(parsed.tokenizer) + tokenizer = AutoTokenizer.from_pretrained( + parsed.tokenizer, trust_remote_code=parsed.trust_remote_code) parsed.eos_text = tokenizer.eos_token # now that we have validated them, change BOS/EOS to strings @@ -171,6 +179,7 @@ def get_task_args( bos_text: str, no_wrap: bool, compression: str, + trust_remote_code: bool, ) -> Iterable: """Get download_and_convert arguments split across n_groups. 
@@ -187,6 +196,7 @@ def get_task_args( bos_text (str): Text to prepend to each example to separate concatenated samples no_wrap: (bool): Whether to let text examples wrap across multiple training examples compression (str): The compression algorithm to use for MDS writing + trust_remote_code (bool): If true, allows custom code to be executed to load the tokenizer """ num_objects = len(object_names) objs_per_group = math.ceil(num_objects / n_groups) @@ -202,6 +212,7 @@ def get_task_args( bos_text, no_wrap, compression, + trust_remote_code, ) @@ -223,6 +234,7 @@ def download_and_convert( bos_text: str, no_wrap: bool, compression: str, + trust_remote_code: bool, ): """Downloads and converts text fies to MDS format. @@ -236,6 +248,7 @@ def download_and_convert( bos_text (str): Text to prepend to each example to separate concatenated samples no_wrap: (bool): Whether to let text examples wrap across multiple training examples compression (str): The compression algorithm to use for MDS writing + trust_remote_code (bool): If true, allows custom code to be executed to load the tokenizer """ object_store = maybe_create_object_store_from_uri(input_folder) @@ -244,7 +257,8 @@ def download_and_convert( downloading_iter = DownloadingIterable(object_names=file_names, output_folder=tmp_dir, object_store=object_store) - tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) + tokenizer = AutoTokenizer.from_pretrained( + tokenizer_name, trust_remote_code=trust_remote_code) tokenizer.model_max_length = 5000000000 # Hack to prevent warnings from HuggingFace # Use the ConcatTokensDataset from LLM-foundry to concatenate sequences of tokens up @@ -353,6 +367,7 @@ def convert_text_to_mds( processes: int, args_str: str, reprocess: bool, + trust_remote_code: bool, ): """Convert a folder of text files to MDS format. @@ -368,6 +383,7 @@ def convert_text_to_mds( processes (int): The number of processes to use. 
args_str (str): String representation of the arguments reprocess (bool): Whether to always reprocess the given folder of text files + trust_remote_code (bool): If true, allows custom code to be executed to load the tokenizer """ is_remote_output = is_remote_path(output_folder) @@ -396,7 +412,7 @@ def convert_text_to_mds( # Download and convert the text files in parallel args = get_task_args(object_names, local_output_folder, input_folder, processes, tokenizer_name, concat_tokens, eos_text, - bos_text, no_wrap, compression) + bos_text, no_wrap, compression, trust_remote_code) with ProcessPoolExecutor(max_workers=processes) as executor: list(executor.map(download_and_convert_starargs, args)) @@ -405,7 +421,7 @@ def convert_text_to_mds( else: download_and_convert(object_names, local_output_folder, input_folder, tokenizer_name, concat_tokens, eos_text, bos_text, - no_wrap, compression) + no_wrap, compression, trust_remote_code) # Write a done file with the args and object names write_done_file(local_output_folder, args_str, object_names) @@ -462,6 +478,7 @@ def _args_str(original_args: Namespace) -> str: compression=args.compression, processes=args.processes, reprocess=args.reprocess, + trust_remote_code=args.trust_remote_code, args_str=_args_str(args)) except Exception as e: if mosaicml_logger is not None: diff --git a/scripts/inference/hf_generate.py b/scripts/inference/hf_generate.py index 6ac645e5b7..57193136ec 100644 --- a/scripts/inference/hf_generate.py +++ b/scripts/inference/hf_generate.py @@ -206,7 +206,6 @@ def main(args: Namespace) -> None: if device is not None: print(f'Placing model on {device=}...') model.to(device) - model.to(model_dtype) except Exception as e: raise RuntimeError( 'Unable to load HF model. 
' + diff --git a/scripts/train/train.py b/scripts/train/train.py index f0a6038dc8..76156d4577 100644 --- a/scripts/train/train.py +++ b/scripts/train/train.py @@ -80,16 +80,16 @@ def validate_config(cfg: DictConfig): fsdp_config = cfg.get('fsdp_config', None) act_ckpt = fsdp_config.get('activation_checkpointing', False) act_ckpt_reentrant = fsdp_config.get( - 'activation_checkpointing_reentrant', True) - if fsdp_config is not None and act_ckpt == True and act_ckpt_reentrant == False: + 'activation_checkpointing_reentrant', False) + if fsdp_config is not None and act_ckpt == True and act_ckpt_reentrant == True: warnings.warn( '`te.Linear` layers do not support activation_checkpointing with ' - + '`activation_checkpointing_reentrant = False`. ' + - 'Setting cfg.fsdp_config.activation_checkpointing_reentrant=True.' + + '`activation_checkpointing_reentrant = True`. ' + + 'Setting cfg.fsdp_config.activation_checkpointing_reentrant=False.' ) - cfg.fsdp_config.activation_checkpointing_reentrant = True + cfg.fsdp_config.activation_checkpointing_reentrant = False - if 'te' in cfg.model.get('ffn_config', {}).get('ffn_type', 'mptmlp'): + if cfg.model.get('ffn_config', {}).get('ffn_type', 'mptmlp') == 'te_ln_mlp': warnings.warn( '`te.LayerNormMLP` requires has issues with torch._dynamo. ' + 'Setting `torch._dynamo.config.suppress_errors = True` and falling back to eager.' @@ -101,6 +101,17 @@ def validate_config(cfg: DictConfig): '`load_in_8bit` is only supported for evaluation rather than training.' ) + if cfg.model.get('ffn_config', {}).get('ffn_type', + 'mptmlp') in ('mb_moe', 'mb_dmoe'): + moe_world_size = cfg.model.get('ffn_config', + {}).get('moe_world_size', 1) + use_orig_params = cfg.get('fsdp_config', + {}).get('use_orig_params', True) + if moe_world_size > 1 and not use_orig_params: + raise ValueError( + f'MoEs with expert parallelism (moe_world_size {moe_world_size} > 1) require `use_orig_params=True`.' 
+ ) + def main(cfg: DictConfig) -> Trainer: # Run user provided code if specified @@ -323,6 +334,10 @@ def main(cfg: DictConfig) -> Trainer: 'load_ignore_keys', must_exist=False, default_value=None) + save_ignore_keys: Optional[List[str]] = pop_config(cfg, + 'save_ignore_keys', + must_exist=False, + default_value=None) compile_config: Optional[Dict[str, Any]] = pop_config(cfg, 'compile_config', must_exist=False, @@ -520,11 +535,20 @@ def main(cfg: DictConfig) -> Trainer: ) # Log number of parameters - n_params = sum(p.numel() for p in model.parameters()) - n_trainable_params = sum( - p.numel() for p in model.parameters() if p.requires_grad) + if hasattr(model, 'n_total_params'): + n_params = model.n_total_params + n_trainable_params = n_params # TODO: we currently assume all parameters are trainable. + else: + n_params = sum(p.numel() for p in model.parameters()) + n_trainable_params = sum( + p.numel() for p in model.parameters() if p.requires_grad) + if hasattr(model, 'n_active_params'): + n_active_params = model.n_active_params + else: + n_active_params = n_params logged_cfg.update({ 'n_params': n_params, + 'n_active_params': n_active_params, 'n_trainable_params': n_trainable_params, }) @@ -580,6 +604,7 @@ def main(cfg: DictConfig) -> Trainer: load_weights_only=load_weights_only, load_strict_model_weights=load_strict_model_weights, load_ignore_keys=load_ignore_keys, + save_ignore_keys=save_ignore_keys, autoresume=autoresume, python_log_level=python_log_level, dist_timeout=dist_timeout, diff --git a/scripts/train/yamls/pretrain/testing-moe.yaml b/scripts/train/yamls/pretrain/testing-moe.yaml new file mode 100644 index 0000000000..eea2b999b7 --- /dev/null +++ b/scripts/train/yamls/pretrain/testing-moe.yaml @@ -0,0 +1,117 @@ +data_local: ./my-copy-c4 +data_remote: # If blank, files must be present in data_local +max_seq_len: 128 +global_seed: 17 + +# Run Name +run_name: # If left blank, will be read from env var $RUN_NAME + +# Model +model: + name: mpt_causal_lm + 
init_device: meta + d_model: 128 + ffn_config: + ffn_type: mb_dmoe + memory_optimized_mlp: true + moe_lbl_in_fp32: false + moe_loss_weight: 0.01 + moe_num_experts: 4 + moe_top_k: 2 + moe_world_size: 1 + moe_weight_parallelism: false + uniform_expert_assignment: false + n_heads: 2 + n_layers: 2 + expansion_ratio: 1 + max_seq_len: ${max_seq_len} + vocab_size: 50368 + attn_config: + attn_impl: torch + loss_fn: torch_crossentropy + +# Tokenizer +tokenizer: + name: EleutherAI/gpt-neox-20b + kwargs: + model_max_length: ${max_seq_len} + +# Dataloaders +train_loader: + name: text + dataset: + local: ${data_local} + remote: ${data_remote} + split: train + shuffle: true + max_seq_len: ${max_seq_len} + shuffle_seed: ${global_seed} + drop_last: true + num_workers: 8 + +eval_loader: + name: text + dataset: + local: ${data_local} + remote: ${data_remote} + split: val + shuffle: false + max_seq_len: ${max_seq_len} + shuffle_seed: ${global_seed} + drop_last: false + num_workers: 8 + +# Optimization +scheduler: + name: cosine_with_warmup + t_warmup: 100ba + alpha_f: 0.1 + +optimizer: + name: decoupled_adamw + lr: 6.0e-4 + betas: + - 0.9 + - 0.95 + eps: 1.0e-08 + weight_decay: 0.0 + +algorithms: + gradient_clipping: + clipping_type: norm + clipping_threshold: 1.0 + +max_duration: 200ba +eval_interval: 100ba +eval_first: false +eval_subset_num_batches: -1 +global_train_batch_size: 256 + +# System +seed: ${global_seed} +device_eval_batch_size: 16 +device_train_microbatch_size: 16 +# device_train_microbatch_size: auto +precision: amp_bf16 + +# FSDP +fsdp_config: + sharding_strategy: FULL_SHARD + mixed_precision: PURE + activation_checkpointing: false + activation_checkpointing_reentrant: false + activation_cpu_offload: false + limit_all_gathers: true + verbose: false + +# Logging +progress_bar: false +log_to_console: true +console_log_interval: 1ba + +callbacks: + speed_monitor: + window_size: 10 + lr_monitor: {} + memory_monitor: {} + runtime_estimator: {} diff --git a/setup.py 
b/setup.py index 79511eeca3..086e759384 100644 --- a/setup.py +++ b/setup.py @@ -55,7 +55,7 @@ 'mlflow>=2.10,<3', 'accelerate>=0.25,<0.26', # for HF inference `device_map` 'transformers>=4.39.3,<4.40', - 'mosaicml-streaming>=0.7.4,<0.8', + 'mosaicml-streaming>=0.7.5,<0.8', 'torch>=2.2.1,<2.3', 'datasets>=2.16,<2.17', 'fsspec==2023.6.0', # newer version results in a bug in datasets that duplicates data @@ -117,8 +117,15 @@ 'openai==1.3.8', 'tiktoken==0.4.0', ] -extra_deps['all-cpu'] = set( - dep for key, deps in extra_deps.items() for dep in deps if 'gpu' not in key) + +extra_deps['megablocks'] = [ + 'megablocks==0.5.1', + 'grouped-gemm==0.1.4', +] + +extra_deps['all-cpu'] = set(dep for key, deps in extra_deps.items() + for dep in deps + if 'gpu' not in key and 'megablocks' not in key) extra_deps['all'] = set(dep for key, deps in extra_deps.items() for dep in deps if key not in {'gpu-flash2', 'all-cpu'}) extra_deps['all-flash2'] = set(dep for key, deps in extra_deps.items() diff --git a/tests/a_scripts/data_prep/test_convert_text_to_mds.py b/tests/a_scripts/data_prep/test_convert_text_to_mds.py index e458cb1dfc..bd96de695c 100644 --- a/tests/a_scripts/data_prep/test_convert_text_to_mds.py +++ b/tests/a_scripts/data_prep/test_convert_text_to_mds.py @@ -106,6 +106,7 @@ def call_convert_text_to_mds() -> None: processes=processes, args_str='Namespace()', reprocess=False, + trust_remote_code=False, ) call_convert_text_to_mds() @@ -195,6 +196,7 @@ def call_convert_text_to_mds(reprocess: bool): processes=1, args_str='Namespace()', reprocess=reprocess, + trust_remote_code=False, ) # Create input text data @@ -234,6 +236,7 @@ def test_input_folder_not_exist(tmp_path: pathlib.Path): processes=1, args_str='Namespace()', reprocess=False, + trust_remote_code=False, ) diff --git a/tests/a_scripts/inference/test_convert_composer_to_hf.py b/tests/a_scripts/inference/test_convert_composer_to_hf.py index 7b4ef1e058..061227d8a4 100644 --- 
a/tests/a_scripts/inference/test_convert_composer_to_hf.py +++ b/tests/a_scripts/inference/test_convert_composer_to_hf.py @@ -18,14 +18,17 @@ from composer.utils import dist, get_device from omegaconf import DictConfig from omegaconf import OmegaConf as om +from torch.distributed._tensor.api import DTensor from torch.utils.data import DataLoader from transformers import PreTrainedModel, PreTrainedTokenizerBase from llmfoundry.callbacks import HuggingFaceCheckpointer from llmfoundry.callbacks.hf_checkpointer import _maybe_get_license_filename from llmfoundry.data.finetuning import build_finetuning_dataloader +from llmfoundry.models.mpt import MPTConfig from llmfoundry.utils.builders import (build_composer_model, build_optimizer, build_tokenizer) +from llmfoundry.utils.config_utils import process_init_device from scripts.inference.convert_composer_to_hf import convert_composer_to_hf from tests.data_utils import make_tiny_ft_dataset @@ -191,9 +194,18 @@ def check_hf_tokenizer_equivalence(tokenizer1: PreTrainedTokenizerBase, assert tokenizer1.__dict__ == tokenizer2.__dict__ +def remove_moe_world_size(config: MPTConfig): + if hasattr(config, 'ffn_config'): + if 'moe_world_size' in config.ffn_config: + config.ffn_config.pop('moe_world_size') + + def check_hf_model_equivalence(model1: PreTrainedModel, model2: PreTrainedModel, just_lora: bool = False): + remove_moe_world_size(model1.config) + remove_moe_world_size(model2.config) + expected_model_config_dict = model1.config.to_dict() new_model_config_dict = model2.config.to_dict() @@ -225,6 +237,7 @@ def check_hf_model_equivalence(model1: PreTrainedModel, assert torch.equal(p1.cpu(), p2.cpu()) +# TODO(GRT-2435): Change to fixture def delete_transformers_cache(): # Only delete the files on local rank 0, otherwise race conditions are created if not dist.get_local_rank() == 0: @@ -421,6 +434,35 @@ def _get_model_and_tokenizer(model: str, max_seq_len: int, 'tie_word_embeddings': tie_word_embeddings, } tokenizer_name = 
'EleutherAI/gpt-neox-20b' + elif model == 'mptmoe': + # Test export on moe_world_size 1 + model_cfg = { + 'name': 'mpt_causal_lm', + 'init_device': 'cpu', + 'd_model': 128, + 'n_heads': 2, + 'n_layers': 2, + 'expansion_ratio': 1, + 'ffn_config': { + 'ffn_type': 'mb_dmoe', + 'memory_optimized_mlp': True, + 'moe_lbl_in_fp32': False, + 'moe_loss_weight': 0.01, + 'moe_num_experts': 4, + 'moe_top_k': 2, + 'moe_world_size': 1, + 'moe_weight_parallelism': False, + 'uniform_expert_assignment': False, + }, + 'max_seq_len': max_seq_len, + 'vocab_size': 50368, + 'attn_config': { + 'attn_impl': 'torch', + }, + 'loss_fn': 'torch_crossentropy', + 'no_bias': True, + } + tokenizer_name = 'EleutherAI/gpt-neox-20b' elif model == 'neo': assert tie_word_embeddings is None model_cfg = { @@ -645,6 +687,7 @@ def _assert_checkpoint_equivalence(tmp_path: pathlib.Path, [ ('mpt', True, None), ('mpt', False, None), + ('mptmoe', None, None), ('neo', None, None), ('llama2', None, None), ('llama2', None, { @@ -680,6 +723,8 @@ def test_huggingface_conversion_callback( expected_normal_checkpoints: int, peft_config: Optional[dict], ): + if model == 'mptmoe' and fsdp_state_dict_type is None: + pytest.skip('mptmoe requires FSDP') delete_transformers_cache() dist.initialize_dist(get_device('gpu')) @@ -697,7 +742,7 @@ def test_huggingface_conversion_callback( precision=precision_str, mlflow_registered_model_name='dummy-registered-name') - # get small version of each model + # Get small version of each model model_cfg, tokenizer_name = _get_model_and_tokenizer( model, max_seq_len, tie_word_embeddings) assert model_cfg is not None @@ -781,9 +826,12 @@ def test_huggingface_conversion_callback( delete_transformers_cache() +# TODO(GRT-2431): Refactor as enums @pytest.mark.parametrize( 'model,tie_word_embeddings', - [('mpt', True), ('mpt', False), ('neo', None), ('llama2', None)], + [('mpt', True), ('mpt', False), + pytest.param('mptmoe', None, marks=pytest.mark.gpu), ('neo', None), + ('llama2', None)], ) 
def test_convert_and_generate(model: str, tie_word_embeddings: bool, tmp_path: pathlib.Path): @@ -794,6 +842,9 @@ def test_convert_and_generate(model: str, tie_word_embeddings: bool, om_cfg = get_config( conf_path='scripts/train/yamls/pretrain/testing.yaml') om_cfg['tie_word_embeddings'] = tie_word_embeddings + elif model == 'mptmoe': + om_cfg = get_config( + conf_path='scripts/train/yamls/pretrain/testing-moe.yaml') elif model == 'neo': assert tie_word_embeddings is None om_cfg = get_config( @@ -824,7 +875,8 @@ def test_convert_and_generate(model: str, tie_word_embeddings: bool, cfg=om_cfg['model'], tokenizer=tokenizer, ) - trainer = Trainer(model=original_model, device='cpu') + trainer = Trainer(model=original_model, + device='cpu' if not model == 'mptmoe' else 'gpu') trainer.save_checkpoint(os.path.join(tmp_path, 'checkpoint.pt')) args = Namespace(composer_path=os.path.join(tmp_path, 'checkpoint.pt'), @@ -845,8 +897,15 @@ def test_convert_and_generate(model: str, tie_word_embeddings: bool, tokenizer = transformers.AutoTokenizer.from_pretrained( os.path.join(tmp_path, 'hf-output-folder'), trust_remote_code=True) - output = loaded_model.generate(tokenizer('hello', - return_tensors='pt')['input_ids'], + device = 'cuda' if model == 'mptmoe' else 'cpu' + precision = torch.bfloat16 if model == 'mptmoe' else torch.float32 + original_model.to(device) + original_model.to(precision) + loaded_model.to(device) + loaded_model.to(precision) + + output = loaded_model.generate(tokenizer( + 'hello', return_tensors='pt')['input_ids'].to(device), max_new_tokens=1) assert output.shape == (1, 2 + (1 if model == 'llama2' else 0)) @@ -863,16 +922,21 @@ def test_convert_and_generate(model: str, tie_word_embeddings: bool, delete_transformers_cache() +@pytest.mark.parametrize('conf_path', [ + 'scripts/train/yamls/pretrain/testing.yaml', + pytest.param('scripts/train/yamls/pretrain/testing-moe.yaml', + marks=pytest.mark.gpu), +]) @pytest.mark.parametrize('tie_word_embeddings', [True, 
False]) def test_convert_and_generate_meta(tie_word_embeddings: str, - tmp_path: pathlib.Path): + tmp_path: pathlib.Path, conf_path: str): delete_transformers_cache() from composer.utils import dist gathered_paths = dist.all_gather_object(tmp_path) tmp_path_gathered = gathered_paths[0] - om_cfg = get_config(conf_path='scripts/train/yamls/pretrain/testing.yaml') + om_cfg = get_config(conf_path=conf_path) om_cfg['model']['init_device'] = 'cpu' om_cfg['tie_word_embeddings'] = tie_word_embeddings @@ -883,7 +947,8 @@ def test_convert_and_generate_meta(tie_word_embeddings: str, cfg=om_cfg['model'], tokenizer=tokenizer, ) - trainer = Trainer(model=original_model, device='cpu') + trainer = Trainer(model=original_model, + device='cpu' if not 'moe' in conf_path else 'gpu') trainer.save_checkpoint(os.path.join(tmp_path_gathered, 'checkpoint.pt')) # patch in the meta device for testing @@ -915,8 +980,15 @@ def test_convert_and_generate_meta(tie_word_embeddings: str, os.path.join(tmp_path_gathered, 'hf-output-folder'), trust_remote_code=True) - output = loaded_model.generate(tokenizer('hello', - return_tensors='pt')['input_ids'], + device = 'cuda' if 'moe' in conf_path else 'cpu' + precision = torch.bfloat16 if 'moe' in conf_path else torch.float32 + original_model.to(device) + original_model.to(precision) + loaded_model.to(device) + loaded_model.to(precision) + + output = loaded_model.generate(tokenizer( + 'hello', return_tensors='pt')['input_ids'].to(device), max_new_tokens=1) assert output.shape == (1, 2) @@ -933,6 +1005,253 @@ def test_convert_and_generate_meta(tie_word_embeddings: str, delete_transformers_cache() +@pytest.mark.world_size(4) +@pytest.mark.gpu +@pytest.mark.parametrize('num_experts', [2, 4, 8]) +@pytest.mark.parametrize('sharding_strategy', ['FULL_SHARD', 'HYBRID_SHARD']) +def test_mptmoe_huggingface_conversion_callback( + tmp_path: pathlib.Path, + num_experts: int, + sharding_strategy: str, + hf_save_interval: str = '1ba', + save_interval: str = '1ba', + 
max_duration: str = '1ba', + expected_hf_checkpoints: int = 1, + expected_normal_checkpoints: int = 1, +): + + delete_transformers_cache() + + dist.initialize_dist(get_device('gpu')) + if dist.get_world_size() != 4: + pytest.skip('This test requires 4 GPUs') + + max_seq_len = 16 + device_batch_size = 1 + dataset_size = 2 + precision_str = 'float32' + precision = torch.float32 + batches_per_epoch = math.ceil(dataset_size / (device_batch_size * 2)) + + checkpointer_callback = HuggingFaceCheckpointer( + save_folder=os.path.join(tmp_path, 'checkpoints'), + save_interval=hf_save_interval, + precision=precision_str, + ) + + # get small version of each model + model_cfg = None + tokenizer_name = None + + # Test export on moe_world_size 1 + model_cfg = { + 'name': 'mpt_causal_lm', + 'init_device': 'cpu', + 'd_model': 128, + 'n_heads': 2, + 'n_layers': 2, + 'expansion_ratio': 1, + 'ffn_config': { + 'ffn_type': + 'mb_dmoe', + 'memory_optimized_mlp': + True, + 'moe_lbl_in_fp32': + False, + 'moe_loss_weight': + 0.01, + 'moe_num_experts': + num_experts, + 'moe_top_k': + 2, + 'moe_world_size': + 2, + 'moe_weight_parallelism': + False, + 'uniform_expert_assignment': + True, + 'mlp_impl': + 'grouped', + 'mlp_type': + 'glu', + 'device_mesh': [1, 2] if sharding_strategy == 'HYBRID_SHARD' else [ + 2, + ], + }, + 'precision': 'amp_bf16', + 'max_seq_len': max_seq_len, + 'vocab_size': 50368, + 'attn_config': { + 'attn_impl': 'torch', + }, + 'loss_fn': 'torch_crossentropy', + 'no_bias': True, + } + tokenizer_name = 'EleutherAI/gpt-neox-20b' + assert model_cfg is not None + assert tokenizer_name is not None + model_cfg = om.create(model_cfg) + + fsdp_config = { + 'sharding_strategy': sharding_strategy, + 'mixed_precision': 'PURE', + 'activation_checkpointing': False, + 'activation_checkpointing_reentrant': False, + 'activation_cpu_offload': False, + 'limit_all_gathers': True, + 'device_mesh': [1, 4] if sharding_strategy == 'HYBRID_SHARD' else [ + 4, + ], + 'use_orig_params': True, + } + + 
tiny_dataset_folder_path = os.path.join(os.getcwd(), 'test-ift-data-small') + tiny_dataset_path = os.path.join(tiny_dataset_folder_path, 'train.jsonl') + if dist.get_global_rank() == 0: + make_tiny_ft_dataset(path=tiny_dataset_path, size=dataset_size) + + dataloader_cfg = { + 'name': 'finetuning', + 'dataset': { + 'hf_name': tiny_dataset_folder_path, + 'split': 'train', + 'max_seq_len': max_seq_len, + 'decoder_only_format': True, + 'allow_pad_trimming': False, + 'packing_ratio': None, + 'shuffle': True, + }, + 'drop_last': False, + 'num_workers': 0, + 'pin_memory': False, + 'prefetch_factor': None, + 'persistent_workers': False, + 'timeout': 0 + } + + dataloader_cfg = om.create(dataloader_cfg) + + tokenizer = build_tokenizer( + tokenizer_name=tokenizer_name, + tokenizer_kwargs={'model_max_length': max_seq_len}, + ) + + train_dataloader = build_finetuning_dataloader( + dataloader_cfg, + tokenizer, + device_batch_size, + ) + + optimizer_config = { + 'name': 'decoupled_adamw', + 'lr': 6e-4, + 'betas': [0.9, 0.95], + 'eps': 1e-8, + 'weight_decay': 0.0, + } + optimizer_name = optimizer_config.pop('name') + + init_context = process_init_device(model_cfg, fsdp_config) + original_model = build_composer_model( + name=model_cfg.name, + cfg=model_cfg, + tokenizer=tokenizer, + init_context=init_context, + ) + + optimizer = build_optimizer(original_model, optimizer_name, + optimizer_config) + trainer = Trainer( + model=original_model, + device='gpu', + fsdp_config=fsdp_config, + train_dataloader=train_dataloader, + save_folder=os.path.join(tmp_path, 'checkpoints'), + save_interval=save_interval, + max_duration=max_duration, + callbacks=[checkpointer_callback], + optimizers=optimizer, + save_latest_filename=None, + precision=model_cfg.pop('precision', None), + save_weights_only=True, + ) + trainer.fit() + #self.state.outputs = self.state.model(self.state.batch) + batch = trainer.state.batch + model_output_logits = trainer.state.model(batch).logits + + # summon full params to 
check equivalence + from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + with FSDP.summon_full_params(trainer.state.model, + writeback=False, + recurse=True): + loaded_model = None + loaded_tokenizer = None + # Only rank zero is saving the huggingface checkpoints, so only check + # for equivalence on rank zero + if dist.get_global_rank() == 0: + normal_checkpoints = [ + name + for name in os.listdir(os.path.join(tmp_path, 'checkpoints')) + if name != 'huggingface' + ] + huggingface_checkpoints = [ + name for name in os.listdir( + os.path.join(tmp_path, 'checkpoints', 'huggingface')) + ] + assert len(normal_checkpoints) == expected_normal_checkpoints + assert len(huggingface_checkpoints) == expected_hf_checkpoints + + # Patch flash_attn package to be empty to simulate loading the model in + # an environment without flash atttention installed + with patch.dict('sys.modules', {'flash_attn': None}): + # Load the last huggingface checkpoint + loaded_model = transformers.AutoModelForCausalLM.from_pretrained( + os.path.join(tmp_path, 'checkpoints', 'huggingface', + f'ba1'), + trust_remote_code=True, + ) + + # Check that the loaded model has the correct precision, and then set it back + # to the original for the equivalence check + assert loaded_model.config.torch_dtype == precision + loaded_model.config.torch_dtype = original_model.model.config.torch_dtype + + loaded_tokenizer = transformers.AutoTokenizer.from_pretrained( + os.path.join(tmp_path, 'checkpoints', 'huggingface', + f'ba{batches_per_epoch}'), + trust_remote_code=True, + ) + for n, p in trainer.state.model.model.named_parameters(): + if isinstance(p, DTensor): + submodule_name, param_name = '.'.join( + n.split('.')[:-1]), n.split('.')[-1] + submodule = trainer.state.model.model.get_submodule( + submodule_name) + param_tensor = p.full_tensor() + param = torch.nn.Parameter(param_tensor) + submodule.register_parameter(param_name, param) + + if dist.get_global_rank() == 0: + 
check_hf_model_equivalence(trainer.state.model.model, loaded_model) + check_hf_tokenizer_equivalence(tokenizer, loaded_tokenizer) + + # Check output equivalence + loaded_model = loaded_model.cuda().bfloat16() # type: ignore + loaded_model_logits = loaded_model( + input_ids=batch.get('input_ids', None), + attention_mask=batch.get('attention_mask', None), + prefix_mask=batch.get('bidirectional_mask', None), + sequence_id=batch.get('sequence_id', None), + inputs_embeds=batch.get('inputs_embeds', None), + ).logits + assert torch.equal(loaded_model_logits, model_output_logits) + + dist.barrier() + + delete_transformers_cache() + + @pytest.mark.parametrize( 'license_file_name', ['LICENSE', 'LICENSE.txt', 'license', 'license.md', None]) diff --git a/tests/a_scripts/train/test_train.py b/tests/a_scripts/train/test_train.py index 68ed9d421c..ff885ac735 100644 --- a/tests/a_scripts/train/test_train.py +++ b/tests/a_scripts/train/test_train.py @@ -1,6 +1,8 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 + import copy +import os import pathlib from typing import Optional @@ -9,9 +11,10 @@ from omegaconf import DictConfig, ListConfig from omegaconf import OmegaConf as om -from scripts.train.train import main # noqa: E402 +from scripts.train.train import main, validate_config # noqa: E402 from tests.data_utils import (create_arxiv_dataset, create_c4_dataset_xxsmall, gpt_tiny_cfg) +from tests.fixtures.autouse import REPO_DIR @pytest.mark.parametrize('averages', [{ @@ -144,6 +147,23 @@ def test_train_multi_eval(tmp_path: pathlib.Path): tuple) +def test_validate_config(): + conf_path: str = os.path.join( + REPO_DIR, + 'scripts/train/yamls/pretrain/testing-moe.yaml', + ) + with open(conf_path) as f: + test_cfg: DictConfig = om.load(f) # type: ignore + test_cfg.model.ffn_config.moe_world_size = 4 + test_cfg.fsdp_config.use_orig_params = False + with pytest.raises( + ValueError, + match= + 'MoEs with expert parallelism (.*) require 
`use_orig_params=True`.' + ): + validate_config(test_cfg) + + def test_eval_metrics_with_no_train_metrics(tmp_path: pathlib.Path): """Test using use_train_metrics=False does not disable eval metrics.""" c4_dataset_name = create_c4_dataset_xxsmall(tmp_path) diff --git a/tests/callbacks/test_mbmoe_tok_per_expert_callback.py b/tests/callbacks/test_mbmoe_tok_per_expert_callback.py new file mode 100644 index 0000000000..79a625b4e4 --- /dev/null +++ b/tests/callbacks/test_mbmoe_tok_per_expert_callback.py @@ -0,0 +1,11 @@ +# Copyright 2024 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +from llmfoundry.utils.builders import build_callback + + +def test_mbmoe_tok_per_expert_builds(): + """Test that the callback can be built.""" + callback = build_callback('mbmoe_tok_per_expert') + assert callback is not None + assert callback.__class__.__name__ == 'MegaBlocksMoE_TokPerExpert' diff --git a/tests/models/layers/test_dmoe.py b/tests/models/layers/test_dmoe.py new file mode 100644 index 0000000000..9c15745793 --- /dev/null +++ b/tests/models/layers/test_dmoe.py @@ -0,0 +1,263 @@ +# Copyright 2024 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +import copy +from contextlib import nullcontext +from functools import partial +from typing import List, Optional + +import pytest +import torch +import torch.distributed as dist +import torch.nn.functional as F +import torch.optim as optim +from torch.distributed._tensor import DTensor, Placement, Replicate, Shard +from torch.distributed._tensor.device_mesh import init_device_mesh +from torch.distributed.checkpoint.state_dict import (StateDictOptions, + get_model_state_dict) +from torch.distributed.tensor.parallel.ddp import _pre_dp_module_transform +from torch.nn.parallel import DistributedDataParallel as DDP + +from llmfoundry.models.layers.dmoe import dMoE +from llmfoundry.models.layers.ffn import dtensorify_param +from llmfoundry.models.mpt.configuration_mpt import MPTConfig +from 
llmfoundry.models.mpt.modeling_mpt import MPTForCausalLM + +try: + import megablocks + is_megablocks_imported = True +except ModuleNotFoundError: + is_megablocks_imported = False + + +def _get_all_inputs( + input_shape: List[int], + dtype: Optional[torch.dtype], +): + world_size: int = dist.get_world_size() + rank: int = dist.get_rank() + device: torch.device = torch.device(f'cuda:{rank}') + all_inputs = [] + for _ in range(world_size): + all_inputs.append(torch.rand( + input_shape, + device=device, + dtype=dtype, + )) + return all_inputs + + +def _get_torch_dtype(fp16: bool, bf16: bool) -> Optional[torch.dtype]: + if fp16: + return torch.float16 + elif bf16: + return torch.bfloat16 + return None + + +@pytest.mark.skipif(not is_megablocks_imported, + reason='This test needs megablocks module') +@pytest.mark.gpu +@pytest.mark.world_size(2) +@pytest.mark.parametrize('moe_num_experts', [8]) +@pytest.mark.parametrize('mlp_type', ['glu', 'mlp']) +@pytest.mark.parametrize('moe_world_size', [1, 2]) +@pytest.mark.parametrize('two_d_input', [True, False]) +def test_dmoe(moe_num_experts: int, mlp_type: str, moe_world_size: int, + two_d_input: bool): + # Generate inputs + rank = dist.get_rank() + batch_size = 2 + seq_len = 3 + hidden_size = 128 + if two_d_input: + input_shape = [batch_size * seq_len, hidden_size] + else: + input_shape = [batch_size, seq_len, hidden_size] + fp16 = False + bf16 = True + dtype = _get_torch_dtype(fp16, bf16) + x = _get_all_inputs(input_shape, dtype)[rank] + + # Construct DDP torch dMoE + device = torch.device(f'cuda:{dist.get_rank()}') + common_args = { + 'hidden_size': hidden_size, + 'ffn_hidden_size': hidden_size, + 'moe_top_k': 2, + 'activation_fn': partial(F.gelu, approximate='none'), + 'moe_jitter_eps': 0.0, # Disable randomiztion + 'moe_normalize_expert_weights': 1, + 'uniform_expert_assignment': False, + 'bias': False, + 'device': device, + 'moe_num_experts': moe_num_experts, + 'mlp_type': mlp_type, + } + + torch_dmoe = 
dMoE(**common_args).to(device, dtype=dtype) + torch_dmoe = DDP( + torch_dmoe, + device_ids=[rank], + ) + torch_dmoe_optimizer = optim.SGD(torch_dmoe.parameters(), lr=0.1) + + # Construct TP MB dMoE + mp_dmoe_args = copy.deepcopy(common_args) + extra_args = { + 'fp16': fp16, + 'bf16': bf16, + 'init_method': partial(torch.nn.init.uniform_, a=-1.0, b=1.0), + } + device_mesh = None + if moe_world_size > 1: + world_size = dist.get_world_size() + assert world_size % moe_world_size == 0 + moe_dp_dim = world_size // moe_world_size + device_mesh = init_device_mesh( + 'cuda', + (moe_dp_dim, moe_world_size), + mesh_dim_names=('weight_parallel', 'expert_parallel'), + ) + expert_parallel_group = device_mesh['expert_parallel'].get_group(0) + extra_args.update( + { + 'moe_expert_model_parallelism': True, + 'expert_parallel_group': expert_parallel_group, + },) + mp_dmoe_args.update(extra_args) + args = megablocks.layers.arguments.Arguments(**mp_dmoe_args,) + mb_dmoe = megablocks.layers.dmoe.dMoE(args).to(device) + mb_dmoe.router = DDP(mb_dmoe.router, device_ids=[rank]) + + if moe_world_size > 1: + assert device_mesh is not None + two_d_placements: List[Placement] = [Replicate(), Shard(0)] + dtensorified_params = [( + name, + dtensorify_param( + param=parameter, + mesh=device_mesh, + placements=two_d_placements, + ), + ) for name, parameter in mb_dmoe.experts.mlp.named_parameters()] + tp_names = [] + for name, dtensorified_param in dtensorified_params: + mb_dmoe.experts.mlp.register_parameter(name, dtensorified_param) + tp_names.append('experts.mlp.' 
+ name) + + _pre_dp_module_transform(mb_dmoe.experts.mlp) + + dp_pg = device_mesh['weight_parallel'].get_group(0) + mb_dmoe.experts = DDP(mb_dmoe.experts, process_group=dp_pg) + + # Copy mb_dmoe's parameters to torch_dmoe + mb_dmoe_state_dict = get_model_state_dict(mb_dmoe, + options=StateDictOptions( + full_state_dict=True,)) + for key, t in mb_dmoe_state_dict.items(): + if key in tp_names: + dtensor_full = DTensor.from_local( + t, # pyright: ignore[reportGeneralTypeIssues] + device_mesh=device_mesh, + placements=two_d_placements, + ).full_tensor() + + mb_dmoe_state_dict[key] = dtensor_full + else: + mb_dmoe.experts = DDP(mb_dmoe.experts, device_ids=[rank]) + mb_dmoe_state_dict = get_model_state_dict(mb_dmoe, + options=StateDictOptions( + full_state_dict=True,)) + mb_dmoe_optimizer = optim.SGD(mb_dmoe.parameters(), lr=0.1) + + # Load mb_dmoe state dict to torch dmoe + torch_dmoe.module.load_state_dict(mb_dmoe_state_dict, strict=True) + + # Run train_step check + torch_y = torch_dmoe(x) + mb_y = mb_dmoe(x) + + torch_y.sum().backward() + mb_y.sum().backward() + torch_dmoe_optimizer.step() + mb_dmoe_optimizer.step() + + torch_y = torch_dmoe(x) + mb_y = mb_dmoe(x) + torch.testing.assert_close(torch_y, mb_y) + + +@pytest.mark.skipif(not is_megablocks_imported, + reason='This test needs megablocks module') +@pytest.mark.gpu +@pytest.mark.parametrize('seqlen', [512]) +@pytest.mark.parametrize('mlp_type', ['glu', 'mlp']) +@pytest.mark.parametrize('precision', ['bf16', 'fp32']) +def test_fwd_equal_dmoe(seqlen: int, precision: str, mlp_type: str): + mb_dmoe_config = MPTConfig(d_model=1024, + n_heads=32, + n_layers=1, + learned_pos_emb=False, + max_seq_len=2048, + vocab_size=100, + no_bias=True, + fuse_norm_attn_norm=True, + tie_word_embeddings=False, + attn_config=dict( + attn_type='grouped_query_attention', + attn_impl='torch', + attn_pdrop=0.0, + clip_qkv=8.0, + kv_n_heads=8, + rope=True, + rope_theta=10000.0, + ), + ffn_config=dict( + ffn_type='mb_dmoe', + 
fc_type='torch', + mlp_type=mlp_type, + moe_world_size=1, + ffn_act_fn={'name': 'silu'}, + ffn_hidden_size=1792, + moe_num_experts=16, + moe_top_k=4, + moe_jitter_eps=0.0, + moe_loss_weight=0.05, + moe_normalize_expert_weights=1.0, + uniform_expert_assignment=False, + )) + device = 'cuda:0' + if precision == 'fp32': + dtype = torch.float32 + context = nullcontext() + elif precision == 'bf16': + dtype = torch.bfloat16 + context = torch.autocast('cuda', torch.bfloat16) + else: + raise ValueError(f'Invalid {precision=}') + + torch_dmoe_config = copy.deepcopy(mb_dmoe_config) + torch_dmoe_config.ffn_config['ffn_type'] = 'torch_dmoe' + + mb_dmoe_model = MPTForCausalLM(mb_dmoe_config).to(device=device, + dtype=dtype) + torch_dmoe_model = MPTForCausalLM(torch_dmoe_config).to(device=device, + dtype=dtype) + + # set same state dicts + torch_dmoe_model.load_state_dict(mb_dmoe_model.state_dict()) + + # tokens + token_ids = torch.randint( + 0, + mb_dmoe_config.vocab_size, + (1, seqlen), + device=device, + dtype=torch.long, + ) + + with context: + mpt_logits = mb_dmoe_model(token_ids).logits + db_logits = torch_dmoe_model(token_ids).logits + assert torch.allclose(mpt_logits, db_logits, rtol=0.01, atol=0.01) diff --git a/tests/models/test_model.py b/tests/models/test_model.py index 7bd8292151..402698cb27 100644 --- a/tests/models/test_model.py +++ b/tests/models/test_model.py @@ -53,7 +53,8 @@ def _load_tokenizer_cfg(cfg: DictConfig) -> Dict: return config -def get_objs(conf_path: str = 'scripts/train/yamls/pretrain/testing.yaml'): +def _get_objs(request: pytest.FixtureRequest, + conf_path: str = 'scripts/train/yamls/pretrain/testing.yaml'): warnings.filterwarnings( action='ignore', message='Torchmetrics v0.9 introduced a new argument class property') @@ -64,16 +65,19 @@ def get_objs(conf_path: str = 'scripts/train/yamls/pretrain/testing.yaml'): fsdp_config = om.to_container(fsdp_config, resolve=True) if fsdp_config else None + # Check if we are running on GPU + is_gpu = False + 
for item in request.session.items: + is_gpu |= item.get_closest_marker('gpu') is not None + # Build Model # For fast initialization, use `meta` device print('Initializing model...') - device = 'cpu' - test_cfg.precision = 'fp32' + device = 'cuda' if is_gpu else 'cpu' + test_cfg.precision = 'amp_bf16' if is_gpu else 'fp32' test_cfg.model.attn_config = { 'attn_impl': 'torch', } - # device = 'cuda' - # test_cfg.precision = 'amp' test_cfg.model.init_device = device test_cfg.device = device @@ -151,9 +155,13 @@ def gen_random_enc_dec_batch(batch_size: int, vocab_size: int, max_seq_len: int, return batch -def test_full_forward_and_backward(batch_size: int = 2): - test_cfg, model, optimizer = get_objs( - conf_path='scripts/train/yamls/pretrain/testing.yaml') +@pytest.mark.parametrize('conf_path', [ + 'scripts/train/yamls/pretrain/testing.yaml', +]) +def test_full_forward_and_backward(request: pytest.FixtureRequest, + conf_path: str, + batch_size: int = 2): + test_cfg, model, optimizer = _get_objs(request=request, conf_path=conf_path) batch = gen_random_batch(batch_size, test_cfg) @@ -170,9 +178,10 @@ def test_full_forward_and_backward(batch_size: int = 2): assert not torch.equal(original_params, updated_params) -def test_full_forward_and_backward_with_inputs_embeds(batch_size: int = 2): - test_cfg, model, optimizer = get_objs( - conf_path='scripts/train/yamls/pretrain/testing.yaml') +def test_full_forward_and_backward_with_inputs_embeds( + request: pytest.FixtureRequest, batch_size: int = 2): + test_cfg, model, optimizer = _get_objs( + request=request, conf_path='scripts/train/yamls/pretrain/testing.yaml') batch = gen_random_batch(batch_size, test_cfg, inputs=['inputs_embeds']) @@ -188,9 +197,10 @@ def test_full_forward_and_backward_with_inputs_embeds(batch_size: int = 2): @pytest.mark.parametrize('inputs', [[], ['input_ids', 'inputs_embeds']]) -def test_invalid_inputs_embeds_input_ids_combinations(inputs: List[str]): - test_cfg, model, _ = get_objs( - 
conf_path='scripts/train/yamls/pretrain/testing.yaml') +def test_invalid_inputs_embeds_input_ids_combinations( + request: pytest.FixtureRequest, inputs: List[str]): + test_cfg, model, _ = _get_objs( + request=request, conf_path='scripts/train/yamls/pretrain/testing.yaml') batch = gen_random_batch(2, test_cfg, inputs=inputs) @@ -199,9 +209,15 @@ def test_invalid_inputs_embeds_input_ids_combinations(inputs: List[str]): _ = model(batch) -def test_attention_mechanism(batch_size: int = 2): - test_cfg, model, _ = get_objs( - conf_path='scripts/train/yamls/pretrain/testing.yaml') +@pytest.mark.parametrize('conf_path', [ + 'scripts/train/yamls/pretrain/testing.yaml', + pytest.param('scripts/train/yamls/pretrain/testing-moe.yaml', + marks=pytest.mark.gpu), +]) +def test_attention_mechanism(request: pytest.FixtureRequest, + conf_path: str, + batch_size: int = 2): + test_cfg, model, _ = _get_objs(request=request, conf_path=conf_path) batch = gen_random_batch(batch_size, test_cfg) @@ -217,43 +233,45 @@ def test_attention_mechanism(batch_size: int = 2): pos = torch.arange(0, S, dtype=torch.long, device=input_ids.device).unsqueeze(0) - tok_emb = model.model.transformer.wte(input_ids) - pos_emb = model.model.transformer.wpe(pos) - x = model.model.transformer.emb_drop(tok_emb + pos_emb) - - # basically the attention mask should be a tensor shape (bsz, seqlen, seqlen) - # wih -inf along the upper triangle as well as wherever there are any pad tokens - # and with 0 everywhere else - expected_zerod_weights = nn.Transformer.generate_square_subsequent_mask(test_cfg.max_seq_len)\ - .reshape(1, test_cfg.max_seq_len, test_cfg.max_seq_len) - expected_zerod_weights = torch.isneginf( - torch.cat(batch_size * [expected_zerod_weights])) - torch_key_padding = torch.cat( # type: ignore - test_cfg.max_seq_len * - [(~attention_mask).reshape(batch_size, 1, test_cfg.max_seq_len)], - axis=1) - expected_zerod_weights |= torch_key_padding - - attn_bias, attention_mask = 
model.model.transformer._attn_bias( - device=x.device, dtype=x.dtype, attention_mask=attention_mask) - - for block in model.model.transformer.blocks: - a = block.norm_1(x) - b, attention_weights, _ = block.attn( - a, - past_key_value=None, - attn_bias=attn_bias, - attention_mask=attention_mask, - is_causal=model.model.transformer.is_causal, - needs_weights=True) - - zerod_weights = (attention_weights == 0) - assert torch.equal(expected_zerod_weights.expand(*zerod_weights.shape), - zerod_weights) - x = x + block.resid_attn_dropout(b) - m = block.norm_2(x) - n = block.ffn(m) - x = x + block.resid_ffn_dropout(n) + with get_precision_context(test_cfg.precision): + tok_emb = model.model.transformer.wte(input_ids) + pos_emb = model.model.transformer.wpe(pos) + x = model.model.transformer.emb_drop(tok_emb + pos_emb) + + # basically the attention mask should be a tensor shape (bsz, seqlen, seqlen) + # wih -inf along the upper triangle as well as wherever there are any pad tokens + # and with 0 everywhere else + expected_zerod_weights = nn.Transformer.generate_square_subsequent_mask(test_cfg.max_seq_len, device=test_cfg.device)\ + .reshape(1, test_cfg.max_seq_len, test_cfg.max_seq_len) + expected_zerod_weights = torch.isneginf( + torch.cat(batch_size * [expected_zerod_weights])) + torch_key_padding = torch.cat( # type: ignore + test_cfg.max_seq_len * + [(~attention_mask).reshape(batch_size, 1, test_cfg.max_seq_len)], + axis=1) + expected_zerod_weights |= torch_key_padding + + attn_bias, attention_mask = model.model.transformer._attn_bias( + device=x.device, dtype=x.dtype, attention_mask=attention_mask) + + for block in model.model.transformer.blocks: + a = block.norm_1(x) + b, attention_weights, _ = block.attn( + a, + past_key_value=None, + attn_bias=attn_bias, + attention_mask=attention_mask, + is_causal=model.model.transformer.is_causal, + needs_weights=True) + + zerod_weights = (attention_weights == 0) + assert torch.equal( + 
expected_zerod_weights.expand(*zerod_weights.shape), + zerod_weights) + x = x + block.resid_attn_dropout(b) + m = block.norm_2(x) + n = block.ffn(m) + x = x + block.resid_ffn_dropout(n) def test_full_forward_and_backward_gpt2_small(batch_size: int = 2): @@ -424,7 +442,6 @@ def test_determinism(attn_impl: str, precision: torch.dtype, ffn_type: str, output_2 = model_2(batch) assert output_1.logits.allclose(output_2.logits, rtol=0.0, atol=0.0), f'differed at step {i}' - loss_1 = model_1.loss(output_1, batch) loss_2 = model_2.loss(output_2, batch) assert isinstance(loss_1, torch.Tensor) diff --git a/tests/models/test_mpt_gen.py b/tests/models/test_mpt_gen.py index 35f130cd46..00c6a1c7a8 100644 --- a/tests/models/test_mpt_gen.py +++ b/tests/models/test_mpt_gen.py @@ -142,6 +142,50 @@ def test_mpt_generate_callback(attn_impl: str, use_alibi: bool, trainer.logger.log_table.assert_called_once() +@pytest.mark.gpu +@pytest.mark.parametrize('device', ['cpu', 'gpu']) +@pytest.mark.parametrize('attn_impl', ['flash', 'torch']) +def test_gen_mpt_moe( + device: str, + attn_impl: str, + build_tiny_mpt: Callable[..., ComposerMPTCausalLM], + mpt_tokenizer: PreTrainedTokenizerBase, +): + if device == 'cpu': + pytest.skip(f'Megablocks is only impelmented on GPU only.') + composer_device = get_device(device) + + model = build_tiny_mpt( + attn_config={ + 'attn_impl': attn_impl, + 'attn_uses_sequence_id': False, + }, + expansion_ratio=1, + ffn_config={ + 'ffn_type': 'mb_dmoe', + 'memory_optimized_mlp': True, + 'moe_lbl_in_fp32': False, + 'moe_loss_weight': 0.01, + 'moe_num_experts': 4, + 'moe_top_k': 2, + 'moe_world_size': 1, + 'moe_weight_parallelism': False, + 'uniform_expert_assignment': False, + }, + ) + model = composer_device.module_to_device(model) + + model.eval() + + with get_precision_context('amp_bf16' if composer_device.name == + 'gpu' else 'fp32'): + _ = model.generate( + composer_device.tensor_to_device( + mpt_tokenizer('hello', return_tensors='pt')['input_ids']), + 
max_new_tokens=10, + ) + + @pytest.mark.gpu @pytest.mark.parametrize('attn_impl', ['flash', 'torch']) @pytest.mark.parametrize('use_alibi', [True, False])