diff --git a/llmfoundry/__init__.py b/llmfoundry/__init__.py index 922f738e9a..54a55d6e97 100644 --- a/llmfoundry/__init__.py +++ b/llmfoundry/__init__.py @@ -28,7 +28,7 @@ MultiheadAttention, attn_bias_shape, build_alibi_bias, build_attn_bias, flash_attn_fn, scaled_multihead_dot_product_attention) from llmfoundry.models.layers.blocks import MPTBlock -from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY, MPTMLP, build_ffn +from llmfoundry.models.layers.ffn import MPTMLP from llmfoundry.models.mpt import (ComposerMPTCausalLM, MPTConfig, MPTForCausalLM, MPTModel, MPTPreTrainedModel) from llmfoundry.tokenizers import TiktokenTokenizerWrapper @@ -37,9 +37,7 @@ 'build_finetuning_dataloader', 'Seq2SeqFinetuningCollator', 'MPTBlock', - 'FFN_CLASS_REGISTRY', 'MPTMLP', - 'build_ffn', 'MPTConfig', 'MPTPreTrainedModel', 'MPTModel', diff --git a/llmfoundry/layers_registry.py b/llmfoundry/layers_registry.py index 3c0a7ebd6e..19b3b1c5cf 100644 --- a/llmfoundry/layers_registry.py +++ b/llmfoundry/layers_registry.py @@ -26,6 +26,34 @@ entry_points=True, description=_fc_description) +_ffns_description = ( + 'The ffns registry is used to register functions that build ffn layers.' + + 'See ffn.py for examples.') +ffns = create_registry('llmfoundry', + 'ffns', + generic_type=Callable, + entry_points=True, + description=_ffns_description) + +_ffns_with_norm_description = ( + 'The ffns_with_norm registry is used to register functions that build ffn layers that apply a normalization layer.' + + 'See ffn.py for examples.') +ffns_with_norm = create_registry('llmfoundry', + 'ffns_with_norm', + generic_type=Callable, + entry_points=True, + description=_ffns_with_norm_description) + +_ffns_with_megablocks_description = ( + 'The ffns_with_megablocks registry is used to register functions that build ffn layers using MegaBlocks.' + + 'See ffn.py for examples.') +ffns_with_megablocks = create_registry( + 'llmfoundry', + 'ffns_with_megablocks', + generic_type=Callable, + entry_points=True, + description=_ffns_with_megablocks_description) + _attention_classes_description = ( 'The attention_classes registry is used to register classes that implement attention layers. 
See ' + 'attention.py for expected constructor signature.') @@ -47,6 +75,9 @@ __all__ = [ 'norms', + 'ffns', + 'ffns_with_norm', + 'ffns_with_megablocks', 'attention_classes', 'attention_implementations', 'fcs', diff --git a/llmfoundry/models/layers/__init__.py b/llmfoundry/models/layers/__init__.py index 5784fcd7e9..dca55098c4 100644 --- a/llmfoundry/models/layers/__init__.py +++ b/llmfoundry/models/layers/__init__.py @@ -8,7 +8,7 @@ from llmfoundry.models.layers.blocks import MPTBlock from llmfoundry.models.layers.custom_embedding import SharedEmbedding from llmfoundry.models.layers.fc import * -from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY, MPTMLP, build_ffn +from llmfoundry.models.layers.ffn import MPTMLP from llmfoundry.models.layers.norm import LPLayerNorm __all__ = [ @@ -24,6 +24,4 @@ 'MPTBlock', 'LPLayerNorm', 'SharedEmbedding', - 'FFN_CLASS_REGISTRY', - 'build_ffn', ] diff --git a/llmfoundry/models/layers/blocks.py b/llmfoundry/models/layers/blocks.py index 1ad9ec954f..40f349368f 100644 --- a/llmfoundry/models/layers/blocks.py +++ b/llmfoundry/models/layers/blocks.py @@ -8,9 +8,9 @@ import torch import torch.nn as nn -from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY, build_ffn +from llmfoundry.layers_registry import ffns_with_norm from llmfoundry.models.layers.layer_builders import (build_attention_layer, - build_norm) + build_ffn, build_norm) try: from flash_attn.bert_padding import unpad_input, pad_input # type: ignore # yapf: disable # isort: skip @@ -73,12 +73,15 @@ def __init__( del kwargs # unused, just to capture any extra args from the config super().__init__() + ffn_type = ffn_config['ffn_type'] + ffn_has_norm = ffn_type in ffns_with_norm + if self.fuse_norm_attn_norm: self.norm_attn_norm = FusedNormAttentionNorm( d_model=d_model, n_heads=n_heads, attn_config=attn_config, - ffn_config=ffn_config, + ffn_has_norm=ffn_has_norm, fc_type=fc_type, resid_pdrop=resid_pdrop, norm_type=norm_type, @@ -116,8 +119,7 @@ def __init__( }, ) self.norm_2 = None - if not getattr(FFN_CLASS_REGISTRY[ffn_config['ffn_type']], - '_has_norm', False): + if not ffn_has_norm: self.norm_2 = build_norm( name=norm_type.lower(), normalized_shape=d_model, @@ -125,12 +127,14 @@ def __init__( ) self.ffn = build_ffn( + name=ffn_type, d_model=d_model, expansion_ratio=expansion_ratio, device=device, bias=not no_bias, - **ffn_config, + ffn_kwargs=ffn_config, ) + self.resid_attn_dropout = nn.Dropout(resid_pdrop) self.resid_ffn_dropout = nn.Dropout(resid_pdrop) self.use_pad_tok_in_ffn = use_pad_tok_in_ffn @@ -198,7 +202,7 @@ def __init__( d_model: int, n_heads: int, attn_config: Optional[Dict] = None, - ffn_config: Optional[Dict] = None, + ffn_has_norm: bool = False, fc_type: str = 'torch', resid_pdrop: float = 0.0, norm_type: str = 'low_precision_layernorm', @@ -208,7 +212,6 @@ def __init__( ): super().__init__() assert attn_config is not None - assert ffn_config is not None assert isinstance(attn_config['attn_type'], str) # necessary to avoid passing extraneous args into attn_class while allowing the use of **kwargs @@ -238,9 +241,9 @@ def __init__( **attn_config_subset_for_attn_class }, ) + self.norm_2 = None - if not getattr(FFN_CLASS_REGISTRY[ffn_config['ffn_type']], '_has_norm', - False): + if not ffn_has_norm: self.norm_2 = build_norm( name=norm_type.lower(), normalized_shape=d_model, diff --git a/llmfoundry/models/layers/dmoe.py b/llmfoundry/models/layers/dmoe.py index 1a981b61c5..19cd67b8aa 100644 --- a/llmfoundry/models/layers/dmoe.py +++ b/llmfoundry/models/layers/dmoe.py @@ 
-1,7 +1,7 @@ # Copyright 2024 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 -from typing import Callable +from typing import Callable, Optional import torch @@ -24,7 +24,8 @@ class LearnedRouter(torch.nn.Module): def __init__(self, hidden_size: int, moe_num_experts: int, moe_top_k: int, moe_jitter_eps: float, moe_normalize_expert_weights: bool, - uniform_expert_assignment: bool, device: torch.device) -> None: + uniform_expert_assignment: bool, + device: Optional[torch.device]) -> None: super().__init__() self.hidden_size: int = hidden_size self.moe_num_experts: int = moe_num_experts @@ -84,7 +85,7 @@ def __init__( ffn_hidden_size: int, moe_num_experts: int, activation_fn: Callable, - device: torch.device, + device: Optional[torch.device], ) -> None: super().__init__() @@ -117,9 +118,14 @@ def forward(self, x: torch.Tensor, expert_idx: int) -> torch.Tensor: class GLU(torch.nn.Module): - def __init__(self, hidden_size: int, ffn_hidden_size: int, - moe_num_experts: int, activation_fn: Callable, - device: torch.device): + def __init__( + self, + hidden_size: int, + ffn_hidden_size: int, + moe_num_experts: int, + activation_fn: Callable, + device: Optional[torch.device], + ): super().__init__() self.hidden_size = hidden_size self.ffn_hidden_size = ffn_hidden_size @@ -157,9 +163,16 @@ def forward(self, x: torch.Tensor, expert_idx: torch.Tensor): class DroplessMLP(torch.nn.Module): - def __init__(self, hidden_size: int, ffn_hidden_size: int, mlp_type: str, - moe_num_experts: int, activation_fn: Callable, bias: bool, - device: torch.device): + def __init__( + self, + hidden_size: int, + ffn_hidden_size: int, + mlp_type: str, + moe_num_experts: int, + activation_fn: Callable, + bias: bool, + device: Optional[torch.device], + ): super().__init__() self.moe_num_experts = moe_num_experts @@ -209,12 +222,20 @@ def forward(self, x: torch.Tensor, scores: torch.Tensor, class dMoE(torch.nn.Module): - def __init__(self, hidden_size: int, ffn_hidden_size: int, - moe_num_experts: int, moe_top_k: int, mlp_type: str, - activation_fn: Callable, moe_jitter_eps: float, - moe_normalize_expert_weights: bool, - uniform_expert_assignment: bool, bias: bool, - device: torch.device): + def __init__( + self, + hidden_size: int, + ffn_hidden_size: int, + moe_num_experts: int, + moe_top_k: int, + mlp_type: str, + activation_fn: Callable, + moe_jitter_eps: float, + moe_normalize_expert_weights: bool, + uniform_expert_assignment: bool, + bias: bool, + device: Optional[torch.device], + ): super().__init__() # Token router. 
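Note on the new registry path (illustration only, not part of this diff): with FFN_CLASS_REGISTRY gone, custom FFNs go through the same ffns registry and layer_builders.build_ffn entry point that MPTBlock now uses. The names IdentityFFN, build_identity_ffn, and 'identity_ffn' below are hypothetical, and the sketch assumes llm-foundry with these changes installed and that construct_from_registry forwards the assembled kwargs to the registered builder, as the new build_ffn indicates.

import torch
import torch.nn as nn

from llmfoundry.layers_registry import ffns
from llmfoundry.models.layers.layer_builders import build_ffn


class IdentityFFN(nn.Module):
    """Toy FFN used only to show the registration path."""

    def __init__(self, d_model: int, device=None, bias: bool = True):
        super().__init__()
        self.proj = nn.Linear(d_model, d_model, bias=bias, device=device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)


def build_identity_ffn(d_model: int, expansion_ratio, device=None,
                       bias: bool = True, **kwargs) -> nn.Module:
    # Receives the kwargs assembled by build_ffn: d_model, expansion_ratio,
    # device, bias, plus anything in ffn_config other than 'ffn_type'.
    return IdentityFFN(d_model=d_model, device=device, bias=bias)


# Same registration style ffn.py uses for the built-in builders,
# e.g. ffns.register('mptmlp', func=build_mptmlp).
ffns.register('identity_ffn', func=build_identity_ffn)

# MPTBlock now calls build_ffn(name=ffn_type, ..., ffn_kwargs=ffn_config);
# a user-registered builder is reachable through the same entry point.
ffn = build_ffn(
    name='identity_ffn',
    d_model=8,
    expansion_ratio=4,
    device=None,
    bias=True,
    ffn_kwargs={'ffn_type': 'identity_ffn'},
)
print(ffn(torch.randn(2, 3, 8)).shape)  # expected: torch.Size([2, 3, 8])
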
diff --git a/llmfoundry/models/layers/ffn.py b/llmfoundry/models/layers/ffn.py index f0b499875a..fb663b4c3c 100644 --- a/llmfoundry/models/layers/ffn.py +++ b/llmfoundry/models/layers/ffn.py @@ -10,8 +10,11 @@ import torch import torch.nn as nn +from torch.distributed import ProcessGroup from torch.distributed._tensor import DeviceMesh, DTensor, Placement, Shard +from llmfoundry.layers_registry import (ffns, ffns_with_megablocks, + ffns_with_norm) from llmfoundry.models.layers.dmoe import dMoE from llmfoundry.models.layers.layer_builders import build_fc @@ -172,25 +175,47 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x)) -FFN_CLASS_REGISTRY = { - 'mptmlp': MPTMLP, - 'mptglu': MPTGLU, - 'torch_dmoe': dMoE, -} +def build_mptglu( + d_model: int, + expansion_ratio: Union[int, float], + fc_type: str = 'torch', + ffn_hidden_size: Optional[int] = None, + ffn_act_fn: Optional[dict] = None, + device: Optional[str] = None, + bias: bool = True, +) -> nn.Module: + return MPTGLU( + d_model=d_model, + expansion_ratio=expansion_ratio, + fc_type=fc_type, + act_fn=resolve_ffn_act_fn(ffn_act_fn), + ffn_hidden_size=ffn_hidden_size, + device=device, + bias=bias, + ) -if is_te_imported: - import transformer_engine.pytorch as te - te.LayerNormMLP._has_norm = True - FFN_CLASS_REGISTRY['te_ln_mlp'] = te.LayerNormMLP -if is_megablocks_imported: - import megablocks - - FFN_CLASS_REGISTRY['mb_moe'] = megablocks.layers.moe.MoE - FFN_CLASS_REGISTRY['mb_dmoe'] = megablocks.layers.dmoe.dMoE +def build_mptmlp( + d_model: int, + expansion_ratio: Union[int, float], + fc_type: str = 'torch', + ffn_hidden_size: Optional[int] = None, + ffn_act_fn: Optional[dict] = None, + device: Optional[str] = None, + bias: bool = True, +) -> nn.Module: + return MPTMLP( + d_model=d_model, + expansion_ratio=expansion_ratio, + fc_type=fc_type, + act_fn=resolve_ffn_act_fn(ffn_act_fn), + ffn_hidden_size=ffn_hidden_size, + device=device, + bias=bias, + ) -def build_ffn( +def build_te_ln_mlp( d_model: int, expansion_ratio: Union[int, float], fc_type: str = 'torch', @@ -200,131 +225,225 @@ def build_ffn( bias: bool = True, **kwargs: Any, ) -> nn.Module: - ffn_type = kwargs.pop('ffn_type') - if ffn_type in ['mptmlp', 'mptglu']: - if len(kwargs) > 0: - raise ValueError( - f'MPTMLP (or MPTGLU) got an unexpected keyword argument: {kwargs}' - ) - return FFN_CLASS_REGISTRY[ffn_type]( - d_model=d_model, - expansion_ratio=expansion_ratio, - fc_type=fc_type, - act_fn=resolve_ffn_act_fn(ffn_act_fn), - ffn_hidden_size=ffn_hidden_size, - device=device, - bias=bias, + assert te is not None + ffn_hidden_size = resolve_ffn_hidden_size(d_model, expansion_ratio, + ffn_hidden_size) + if ffn_act_fn is not None: + raise ValueError( + f'Transformer Engine block does not support custom activation functions.' ) - elif ffn_type == 'te_ln_mlp': - if te is None: - raise RuntimeError( - 'Requirements for TransformerEngine not installed; see install instructions in `README.md`.' - ) - ffn_hidden_size = resolve_ffn_hidden_size(d_model, expansion_ratio, - ffn_hidden_size) - if ffn_act_fn is not None: - raise ValueError( - f'Transformer Engine block does not support custom activation functions.' - ) - return te.LayerNormMLP( - hidden_size=d_model, - ffn_hidden_size=ffn_hidden_size, - bias=bias, - **kwargs, - ) - elif ffn_type in ('mb_moe', 'mb_dmoe'): - if megablocks is None: - raise RuntimeError( - 'Requirements for megablocks not installed; see install instructions in `README.md`.' 
- ) - args = kwargs['args'] - args.bias = bias - args.hidden_size = d_model - args.device = device + return te.LayerNormMLP( + hidden_size=d_model, + ffn_hidden_size=ffn_hidden_size, + bias=bias, + **kwargs, + ) + + +def build_torch_dmoe( + d_model: int, + expansion_ratio: Union[int, float], + ffn_hidden_size: Optional[int] = None, + ffn_act_fn: Optional[dict] = None, + device: Optional[str] = None, + bias: bool = True, + **kwargs: Any, +) -> nn.Module: + moe_num_experts = kwargs.pop('moe_num_experts') + moe_top_k = kwargs.pop('moe_top_k') + mlp_type = kwargs.pop('mlp_type') + moe_jitter_eps = kwargs.pop('moe_jitter_eps') + moe_normalize_expert_weights = kwargs.pop('moe_normalize_expert_weights') + uniform_expert_assignment = kwargs.pop('uniform_expert_assignment') + + fc_type = kwargs.pop('fc_type', 'torch') + del fc_type # Unused + + if len(kwargs) > 0: + raise ValueError(f'Invalid arguments to torch dmoe: {kwargs}.') + + return dMoE( + hidden_size=d_model, + ffn_hidden_size=resolve_ffn_hidden_size(d_model, expansion_ratio, + ffn_hidden_size), + moe_num_experts=moe_num_experts, + moe_top_k=moe_top_k, + mlp_type=mlp_type, + bias=bias, + moe_jitter_eps=moe_jitter_eps, + activation_fn=resolve_ffn_act_fn(ffn_act_fn), + moe_normalize_expert_weights=moe_normalize_expert_weights, + uniform_expert_assignment=uniform_expert_assignment, + device=torch.device(device) if device is not None else None, + ) - ffn_hidden_size = resolve_ffn_hidden_size(d_model, expansion_ratio, - ffn_hidden_size) - args.ffn_hidden_size = ffn_hidden_size - - if ffn_act_fn is not None: - args.activation_fn = resolve_ffn_act_fn(ffn_act_fn) - - moe_world_size = 1 - expert_parallel_group = args.expert_parallel_group - if expert_parallel_group is not None: - moe_world_size = expert_parallel_group.size() - if kwargs.get('moe_world_size') != moe_world_size: - raise RuntimeError( - f'MoE expert_parallel_group configured with incorrect world size.' - ) - if ffn_type == 'mb_moe': - ffn = megablocks.layers.moe.MoE(args) - - # Fused initialization setup - # For param_init_fn, enables shape based init of stacked layers - ffn.experts.mlp._stack_dim = 0 - elif ffn_type == 'mb_dmoe': - ffn = megablocks.layers.dmoe.dMoE(args) - - # Fused initialization setup - # For param_init_fn, enables shape based init of fused layers - n_exp = min(1, args.moe_num_experts // moe_world_size) - ffn.experts.mlp._fused = (0, [ - (n + 1) * args.ffn_hidden_size for n in range(n_exp - 1) - ]) +def _mb_setup_args( + d_model: int, + expansion_ratio: Union[int, float], + ffn_hidden_size: Optional[int], + ffn_act_fn: Optional[dict], + device: Optional[str], + bias: bool, + kwargs: dict[str, Any], +) -> tuple['megablocks.layers.arguments.Arguments', int, ProcessGroup]: + if megablocks is None: + raise RuntimeError( + 'Requirements for megablocks not installed; see install instructions in `README.md`.' 
+ ) + args = kwargs['args'] + args.bias = bias + args.hidden_size = d_model + args.device = device + + ffn_hidden_size = resolve_ffn_hidden_size(d_model, expansion_ratio, + ffn_hidden_size) + args.ffn_hidden_size = ffn_hidden_size + + if ffn_act_fn is not None: + args.activation_fn = resolve_ffn_act_fn(ffn_act_fn) + + moe_world_size = 1 + expert_parallel_group = args.expert_parallel_group + if expert_parallel_group is not None: + moe_world_size = expert_parallel_group.size() + if kwargs.get('moe_world_size') != moe_world_size: + raise RuntimeError( + f'MoE expert_parallel_group configured with incorrect world size.') + + return args, moe_world_size, expert_parallel_group + + +def _patch_ffn_mb( + ffn: nn.Module, + moe_world_size: int, + expert_parallel_group: ProcessGroup, + device_mesh: DeviceMesh, + args: 'megablocks.layers.arguments.Arguments', +): + # Attach args to MLP directly for use in param_init_fn + ffn.experts.mlp.hidden_size = args.ffn_hidden_size + ffn.experts.mlp.expert_parallel_group = expert_parallel_group + ffn.experts.mlp.weight_parallel_group = args.weight_parallel_group + + if moe_world_size > 1: + expert_mesh = device_mesh['expert_parallel'] + expert_placements: List[Placement] = [Shard(0)] + # Register in two loops as you cannot overwrite parameters while iterating over named_parameters() + dtensorified_params = [ + (name, + dtensorify_param(param=parameter, + mesh=expert_mesh, + placements=expert_placements)) + for name, parameter in ffn.experts.mlp.named_parameters() + ] + for name, dtensorified_param in dtensorified_params: + ffn.experts.mlp.register_parameter(name, dtensorified_param) + + if device_mesh.mesh.ndim == 2: + submesh = device_mesh['weight_parallel'] + elif device_mesh.mesh.ndim == 3: + raise RuntimeError(f'HSDP + MoE is not supported.') else: - raise RuntimeError(f'Invalid ffn_type option: {ffn_type}.') - - # Attach args to MLP directly for use in param_init_fn - ffn.experts.mlp.hidden_size = args.ffn_hidden_size - ffn.experts.mlp.expert_parallel_group = expert_parallel_group - ffn.experts.mlp.weight_parallel_group = args.weight_parallel_group - - if moe_world_size > 1: - device_mesh = kwargs['device_mesh'] - - expert_mesh = device_mesh['expert_parallel'] - expert_placements: List[Placement] = [Shard(0)] - # Register in two loops as you cannot overwrite parameters while iterating over named_parameters() - dtensorified_params = [ - (name, - dtensorify_param(param=parameter, - mesh=expert_mesh, - placements=expert_placements)) - for name, parameter in ffn.experts.mlp.named_parameters() - ] - for name, dtensorified_param in dtensorified_params: - ffn.experts.mlp.register_parameter(name, dtensorified_param) - - device_mesh = kwargs['device_mesh'] - if device_mesh.mesh.ndim == 2: - submesh = device_mesh['weight_parallel'] - elif device_mesh.mesh.ndim == 3: - raise RuntimeError(f'HSDP + MoE is not supported.') - else: - raise ValueError( - f'{device_mesh.mesh.ndim=} not supported for MoE.') - - ffn.experts._fsdp_kwargs_dict = { - 'device_mesh': submesh, - } - return ffn - elif ffn_type == 'torch_dmoe': - return dMoE( - hidden_size=d_model, - ffn_hidden_size=resolve_ffn_hidden_size(d_model, expansion_ratio, - ffn_hidden_size), - moe_num_experts=kwargs.pop('moe_num_experts'), - moe_top_k=kwargs.pop('moe_top_k'), - mlp_type=kwargs.pop('mlp_type'), - bias=bias, - moe_jitter_eps=kwargs.pop('moe_jitter_eps'), - activation_fn=resolve_ffn_act_fn(ffn_act_fn), - moe_normalize_expert_weights=kwargs.pop( - 'moe_normalize_expert_weights'), - 
uniform_expert_assignment=kwargs.pop('uniform_expert_assignment'), - device=device, # pyright: ignore[reportGeneralTypeIssues] + raise ValueError(f'{device_mesh.mesh.ndim=} not supported for MoE.') + + ffn.experts._fsdp_kwargs_dict = { + 'device_mesh': submesh, + } + + +def build_mb_moe( + d_model: int, + expansion_ratio: Union[int, float], + ffn_hidden_size: Optional[int] = None, + ffn_act_fn: Optional[dict] = None, + device: Optional[str] = None, + bias: bool = True, + **kwargs: Any, +) -> nn.Module: + if not is_megablocks_imported: + raise RuntimeError( + 'Requirements for megablocks not installed; see install instructions in `README.md`.' + ) + + args, moe_world_size, expert_parallel_group = _mb_setup_args( + d_model=d_model, + expansion_ratio=expansion_ratio, + ffn_hidden_size=ffn_hidden_size, + ffn_act_fn=ffn_act_fn, + device=device, + bias=bias, + kwargs=kwargs, + ) + + ffn = megablocks.layers.moe.MoE(args) + + # Fused initialization setup + # For param_init_fn, enables shape based init of stacked layers + ffn.experts.mlp._stack_dim = 0 + + _patch_ffn_mb( + ffn=ffn, + moe_world_size=moe_world_size, + expert_parallel_group=expert_parallel_group, + device_mesh=kwargs['device_mesh'], + args=args, + ) + + return ffn + + +def build_mb_dmoe( + d_model: int, + expansion_ratio: Union[int, float], + ffn_hidden_size: Optional[int] = None, + ffn_act_fn: Optional[dict] = None, + device: Optional[str] = None, + bias: bool = True, + **kwargs: Any, +) -> nn.Module: + if not is_megablocks_imported: + raise RuntimeError( + 'Requirements for megablocks not installed; see install instructions in `README.md`.' ) - raise ValueError(f'{ffn_type=} not recognized.') + args, moe_world_size, expert_parallel_group = _mb_setup_args( + d_model=d_model, + expansion_ratio=expansion_ratio, + ffn_hidden_size=ffn_hidden_size, + ffn_act_fn=ffn_act_fn, + device=device, + bias=bias, + kwargs=kwargs, + ) + + ffn = megablocks.layers.dmoe.dMoE(args) + + # Fused initialization setup + # For param_init_fn, enables shape based init of fused layers + n_exp = min(1, args.moe_num_experts // moe_world_size) + ffn.experts.mlp._fused = (0, [ + (n + 1) * args.ffn_hidden_size for n in range(n_exp - 1) + ]) + + _patch_ffn_mb( + ffn=ffn, + moe_world_size=moe_world_size, + expert_parallel_group=expert_parallel_group, + device_mesh=kwargs['device_mesh'], + args=args, + ) + + return ffn + + +ffns.register('mptglu', func=build_mptglu) +ffns.register('mptmlp', func=build_mptmlp) +ffns.register('torch_dmoe', func=build_torch_dmoe) + +if is_te_imported: + ffns_with_norm.register('te_ln_mlp', func=build_te_ln_mlp) + +if is_megablocks_imported: + ffns_with_megablocks.register('mb_moe', func=build_mb_moe) + ffns_with_megablocks.register('mb_dmoe', func=build_mb_dmoe) diff --git a/llmfoundry/models/layers/layer_builders.py b/llmfoundry/models/layers/layer_builders.py index 6a725d469a..425fcaf862 100644 --- a/llmfoundry/models/layers/layer_builders.py +++ b/llmfoundry/models/layers/layer_builders.py @@ -5,7 +5,9 @@ import torch -from llmfoundry.layers_registry import attention_classes, fcs, norms +from llmfoundry.layers_registry import (attention_classes, fcs, ffns, + ffns_with_megablocks, ffns_with_norm, + norms) from llmfoundry.utils.registry_utils import construct_from_registry @@ -25,6 +27,50 @@ def build_norm( kwargs=kwargs) +def build_ffn( + name: str, + d_model: int, + expansion_ratio: float, + device: Optional[str], + bias: bool, + ffn_kwargs: Dict[str, Any], +): + + registry_to_use = ffns + if name in ffns_with_norm: + registry_to_use = 
ffns_with_norm + + if name in ffns_with_megablocks: + registry_to_use = ffns_with_megablocks + + kwargs = { + 'd_model': d_model, + 'expansion_ratio': expansion_ratio, + 'device': device, + 'bias': bias, + **{k: v for k, v in ffn_kwargs.items() if k != 'ffn_type'}, + } + + def _validation_function(maybe_module: Any): + if not isinstance(maybe_module, torch.nn.Module): + raise ValueError(f'Function {name} must return a torch.nn.Module.') + + result = construct_from_registry( + name=name, + registry=registry_to_use, + post_validation_function=_validation_function, + partial_function=False, + kwargs=kwargs) + + if name in ffns_with_norm: + result._has_norm = True + + if name in ffns_with_megablocks: + result._uses_megablocks = True + + return result + + def build_attention_layer( name: str, attn_kwargs: Dict[str, Any], diff --git a/llmfoundry/models/mpt/configuration_mpt.py b/llmfoundry/models/mpt/configuration_mpt.py index 4b98fa611d..dbee232f3d 100644 --- a/llmfoundry/models/mpt/configuration_mpt.py +++ b/llmfoundry/models/mpt/configuration_mpt.py @@ -8,6 +8,7 @@ from transformers import PretrainedConfig +from llmfoundry.layers_registry import ffns_with_megablocks from llmfoundry.models.layers.attention import (check_alibi_support, is_flash_v2_installed) from llmfoundry.models.layers.blocks import attn_config_defaults @@ -17,8 +18,7 @@ # Otherwise, certain modules are missing. # isort: off from llmfoundry.models.layers.norm import LPLayerNorm # type: ignore (see note) -from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY # type: ignore (see note) -from llmfoundry.models.layers.layer_builders import build_norm, build_fc # type: ignore (see note) +from llmfoundry.models.layers.layer_builders import build_norm, build_fc, build_ffn # type: ignore (see note) from llmfoundry.models.layers.dmoe import dMoE # type: ignore (see note) from llmfoundry.layers_registry import norms # type: ignore (see note) from llmfoundry.utils.registry_utils import construct_from_registry # type: ignore (see note) @@ -290,7 +290,7 @@ def _validate_config(self) -> None: ) elif self.ffn_config['ffn_type'] in ['mptmlp', 'mptglu']: self.ffn_config['fc_type'] = self.fc_type - elif self.ffn_config['ffn_type'] in ['mb_moe', 'mb_dmoe']: + elif self.ffn_config['ffn_type'] in ffns_with_megablocks: self.ffn_config['return_bias'] = False elif self.ffn_config['ffn_type'] == 'te_ln_mlp': self.ffn_config['bias'] = not self.no_bias diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 4a8f3943af..124ab3db3e 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -20,6 +20,7 @@ from composer.models import HuggingFaceModel from composer.utils import dist +from llmfoundry.layers_registry import ffns_with_megablocks from llmfoundry.models.layers.attention import is_flash_v2_installed if is_flash_v2_installed(): @@ -47,7 +48,6 @@ build_attn_bias, gen_slopes) from llmfoundry.models.layers.blocks import MPTBlock from llmfoundry.models.layers.custom_embedding import SharedEmbedding -from llmfoundry.models.layers.ffn import build_ffn as build_ffn from llmfoundry.models.layers.layer_builders import build_norm from llmfoundry.models.mpt.configuration_mpt import MPTConfig from llmfoundry.models.utils.config_moe_args import config_moe_args @@ -64,6 +64,7 @@ generic_param_init_fn_, # type: ignore (see note) MODEL_INIT_REGISTRY, ) +from llmfoundry.models.layers.ffn import resolve_ffn_act_fn # type: ignore (see note) from llmfoundry.models.utils.act_ckpt 
import (pass_on_block_idx, build_act_ckpt_mod_to_blocks, @@ -324,7 +325,7 @@ def __init__(self, config: MPTConfig): self.emb_drop = nn.Dropout(config.emb_pdrop) self.mb_args = None block_args = config.to_dict() - if block_args['ffn_config']['ffn_type'] in ('mb_moe', 'mb_dmoe'): + if block_args['ffn_config']['ffn_type'] in ffns_with_megablocks: block_args['ffn_config'] = config_moe_args( block_args['ffn_config'], config.d_model, @@ -332,6 +333,7 @@ def __init__(self, config: MPTConfig): config.n_layers, ) self.mb_args = block_args['ffn_config'].get('args') + self.blocks = nn.ModuleList([ MPTBlock( device=config.init_device, @@ -1026,7 +1028,7 @@ def get_targets(self, batch: Mapping) -> torch.Tensor: return targets def forward(self, batch: MutableMapping) -> CausalLMOutputWithPast: - if self.config.ffn_config['ffn_type'] in ('mb_moe', 'mb_dmoe'): + if self.config.ffn_config['ffn_type'] in ffns_with_megablocks: # Clear MegaBlocks MoE load balancing loss cache try: # Add try/catch to avoid transformers complaining and raising errors from megablocks.layers.moe import clear_load_balancing_loss @@ -1053,7 +1055,7 @@ def loss(self, outputs: CausalLMOutputWithPast, else: loss = losses.sum() / (targets != self.loss_fn.ignore_index).sum() - if self.config.ffn_config['ffn_type'] in ('mb_moe', 'mb_dmoe'): + if self.config.ffn_config['ffn_type'] in ffns_with_megablocks: # MegaBlocks MoE load balancing loss try: # Add try/catch to avoid transformers complaining and raising errors from megablocks.layers.moe import batched_load_balancing_loss diff --git a/llmfoundry/models/utils/act_ckpt.py b/llmfoundry/models/utils/act_ckpt.py index fea68492c1..e6cd8bdc58 100644 --- a/llmfoundry/models/utils/act_ckpt.py +++ b/llmfoundry/models/utils/act_ckpt.py @@ -5,9 +5,10 @@ import torch -from llmfoundry.layers_registry import attention_classes, norms +from llmfoundry.layers_registry import (attention_classes, ffns, + ffns_with_megablocks, ffns_with_norm, + norms) from llmfoundry.models.layers.blocks import FusedNormAttentionNorm, MPTBlock -from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY def pass_on_block_idx(parent: torch.nn.Module): @@ -28,14 +29,19 @@ def get_act_ckpt_module(mod_name: str) -> Any: mod_type = attention_classes.get(mod_name) elif mod_name.lower() == 'norm_attn_norm': mod_type = FusedNormAttentionNorm - elif mod_name in FFN_CLASS_REGISTRY: - mod_type = FFN_CLASS_REGISTRY[mod_name] + elif mod_name in ffns: + mod_type = ffns.get(mod_name) + elif mod_name in ffns_with_norm: + mod_type = ffns_with_norm.get(mod_name) + elif mod_name in ffns_with_megablocks: + mod_type = ffns_with_megablocks.get(mod_name) elif mod_name in norms: mod_type = norms.get(mod_name) else: msg = ', '.join( - list(attention_classes.get_all()) + - list(FFN_CLASS_REGISTRY.keys()) + list(norms.get_all()) + + list(attention_classes.keys()) + list(ffns.get_all()) + + list(ffns_with_norm.get_all()) + + list(ffns_with_megablocks.get_all()) + list(norms.get_all()) + ['MPTBlock']) raise ValueError( f'{mod_name} (specified in activation_checkpointing_target) is not a recognized option out of available options {msg}.' 
diff --git a/llmfoundry/models/utils/config_moe_args.py b/llmfoundry/models/utils/config_moe_args.py index 1f7132c281..4de9a47bbc 100644 --- a/llmfoundry/models/utils/config_moe_args.py +++ b/llmfoundry/models/utils/config_moe_args.py @@ -9,6 +9,7 @@ from packaging import version from torch import distributed +from llmfoundry.layers_registry import ffns_with_megablocks from llmfoundry.models.layers.ffn import resolve_ffn_hidden_size @@ -177,7 +178,7 @@ def config_moe_args( Returns: ffn_config (dict): FFN configuration with MoE configured. """ - if ffn_config['ffn_type'] in ('mb_moe', 'mb_dmoe'): + if ffn_config['ffn_type'] in ffns_with_megablocks: return config_megablocks_moe_args( ffn_config=ffn_config, d_model=d_model, diff --git a/llmfoundry/models/utils/mpt_param_count.py b/llmfoundry/models/utils/mpt_param_count.py index d90929713b..cb1a5c0935 100644 --- a/llmfoundry/models/utils/mpt_param_count.py +++ b/llmfoundry/models/utils/mpt_param_count.py @@ -16,6 +16,8 @@ from torch import Tensor, nn from torch.distributed._tensor import DTensor +from llmfoundry.layers_registry import ffns_with_megablocks + def module_n_params(module: nn.Module) -> int: """Gets the number of parameters in this module excluding child modules. @@ -127,7 +129,7 @@ def megablocks_n_active_params(mpt_model) -> int: # type: ignore def mpt_get_total_params(mpt_model) -> int: # type: ignore - """Calculates the total paramter count of an MPT model. + """Calculates the total parameter count of an MPT model. Note: Must be called before model parameters are sharded by FSDP. @@ -138,14 +140,14 @@ def mpt_get_total_params(mpt_model) -> int: # type: ignore Returns: An int for the total number of parameters in this MPT model. """ - if mpt_model.config.ffn_config['ffn_type'] in ('mb_moe', 'mb_dmoe'): + if mpt_model.config.ffn_config['ffn_type'] in ffns_with_megablocks: return megablocks_n_total_params(mpt_model) else: return sum(p.numel() for p in mpt_model.parameters()) def mpt_get_active_params(mpt_model) -> int: # type: ignore - """Calculates the total paramter count of an MPT model. + """Calculates the total parameter count of an MPT model. Note: Must be called before model parameters are sharded by FSDP. @@ -156,7 +158,7 @@ def mpt_get_active_params(mpt_model) -> int: # type: ignore Returns: An int for the active number of parameters in this MPT model. 
""" - if mpt_model.config.ffn_config['ffn_type'] in ('mb_moe', 'mb_dmoe'): + if mpt_model.config.ffn_config['ffn_type'] in ffns_with_megablocks: params = megablocks_n_active_params(mpt_model) else: params = sum(p.numel() for p in mpt_model.parameters()) diff --git a/llmfoundry/registry.py b/llmfoundry/registry.py index ef6629be10..64ed5d7b65 100644 --- a/llmfoundry/registry.py +++ b/llmfoundry/registry.py @@ -13,7 +13,9 @@ from llmfoundry.interfaces import CallbackWithConfig from llmfoundry.layers_registry import (attention_classes, - attention_implementations, fcs, norms) + attention_implementations, fcs, ffns, + ffns_with_megablocks, ffns_with_norm, + norms) from llmfoundry.utils.registry_utils import create_registry _loggers_description = ( @@ -131,6 +133,9 @@ 'metrics', 'dataloaders', 'norms', + 'ffns', + 'ffns_with_norm', + 'ffns_with_megablocks', 'attention_classes', 'attention_implementations', 'fcs', diff --git a/llmfoundry/utils/config_utils.py b/llmfoundry/utils/config_utils.py index d2c3b733c0..a4fd005c3a 100644 --- a/llmfoundry/utils/config_utils.py +++ b/llmfoundry/utils/config_utils.py @@ -11,6 +11,7 @@ from omegaconf import DictConfig, ListConfig from omegaconf import OmegaConf as om +from llmfoundry.layers_registry import ffns_with_megablocks from llmfoundry.models.utils import init_empty_weights log = logging.getLogger(__name__) @@ -131,7 +132,7 @@ def process_init_device(model_cfg: DictConfig, fsdp_config: Optional[Dict]): # Set ffn_config.device_mesh to fsdp_config.device_mesh if fsdp_config is not None and 'device_mesh' in fsdp_config and 'ffn_config' in model_cfg and model_cfg[ - 'ffn_config'].get('ffn_type', None) in {'mb_moe', 'mb_dmoe'}: + 'ffn_config'].get('ffn_type', None) in ffns_with_megablocks: # Raise ValueError if not using device mesh with MoE expert parallelism if fsdp_config['device_mesh'] is None and model_cfg['ffn_config'].get( 'moe_world_size', 1) > 1: diff --git a/llmfoundry/utils/registry_utils.py b/llmfoundry/utils/registry_utils.py index d9c23e6f26..0eeefbae74 100644 --- a/llmfoundry/utils/registry_utils.py +++ b/llmfoundry/utils/registry_utils.py @@ -1,9 +1,11 @@ # Copyright 2024 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 +import copy import functools import importlib.util import os +from contextlib import contextmanager from pathlib import Path from types import ModuleType from typing import (Any, Callable, Dict, Generic, Optional, Sequence, Type, @@ -143,7 +145,7 @@ def construct_from_registry( ) if post_validation_function is not None: - post_validation_function(registered_constructor) + post_validation_function(constructed_item) return constructed_item @@ -174,3 +176,13 @@ def import_file(loc: Union[str, Path]) -> ModuleType: except Exception as e: raise RuntimeError(f'Error executing {loc}') from e return module + + +@contextmanager +def save_registry(): + """Save the registry state and restore after the context manager exits.""" + saved_registry_state = copy.deepcopy(catalogue.REGISTRY) + + yield + + catalogue.REGISTRY = saved_registry_state diff --git a/scripts/train/train.py b/scripts/train/train.py index 76156d4577..a49ae4e26d 100644 --- a/scripts/train/train.py +++ b/scripts/train/train.py @@ -26,6 +26,7 @@ install() from llmfoundry.callbacks import AsyncEval from llmfoundry.data.dataloader import build_dataloader +from llmfoundry.layers_registry import ffns_with_megablocks from llmfoundry.utils.builders import (add_metrics_to_eval_loaders, build_algorithm, build_callback, build_composer_model, build_evaluators, @@ 
-102,7 +103,7 @@ def validate_config(cfg: DictConfig): ) if cfg.model.get('ffn_config', {}).get('ffn_type', - 'mptmlp') in ('mb_moe', 'mb_dmoe'): + 'mptmlp') in ffns_with_megablocks: moe_world_size = cfg.model.get('ffn_config', {}).get('moe_world_size', 1) use_orig_params = cfg.get('fsdp_config', diff --git a/tests/a_scripts/train/test_train.py b/tests/a_scripts/train/test_train.py index ff885ac735..fe58a44459 100644 --- a/tests/a_scripts/train/test_train.py +++ b/tests/a_scripts/train/test_train.py @@ -147,6 +147,7 @@ def test_train_multi_eval(tmp_path: pathlib.Path): tuple) +@pytest.mark.gpu def test_validate_config(): conf_path: str = os.path.join( REPO_DIR, diff --git a/tests/fixtures/autouse.py b/tests/fixtures/autouse.py index ccbe1b69f7..16e3f8ad6f 100644 --- a/tests/fixtures/autouse.py +++ b/tests/fixtures/autouse.py @@ -9,11 +9,19 @@ import torch from composer.utils import dist, get_device, reproducibility +from llmfoundry.utils.registry_utils import save_registry + # Add llm-foundry repo root to path so we can import scripts in the tests REPO_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')) sys.path.append(REPO_DIR) +@pytest.fixture(autouse=True) +def save_registry_fixture(): + with save_registry(): + yield + + @pytest.fixture(autouse=True) def initialize_dist(request: pytest.FixtureRequest): """Initialize the default PyTorch distributed process group for tests.""" diff --git a/tests/models/layers/test_dmoe.py b/tests/models/layers/test_dmoe.py index 9c15745793..c8e7ec3e67 100644 --- a/tests/models/layers/test_dmoe.py +++ b/tests/models/layers/test_dmoe.py @@ -239,6 +239,10 @@ def test_fwd_equal_dmoe(seqlen: int, precision: str, mlp_type: str): torch_dmoe_config = copy.deepcopy(mb_dmoe_config) torch_dmoe_config.ffn_config['ffn_type'] = 'torch_dmoe' + del torch_dmoe_config.ffn_config['moe_world_size'] + del torch_dmoe_config.ffn_config['fc_type'] + del torch_dmoe_config.ffn_config['moe_loss_weight'] + del torch_dmoe_config.ffn_config['return_bias'] mb_dmoe_model = MPTForCausalLM(mb_dmoe_config).to(device=device, dtype=dtype) diff --git a/tests/test_registry.py b/tests/test_registry.py index 29d8e137f3..aaba89c43d 100644 --- a/tests/test_registry.py +++ b/tests/test_registry.py @@ -31,6 +31,9 @@ def test_expected_registries_exist(): 'metrics', 'models', 'norms', + 'ffns', + 'ffns_with_norm', + 'ffns_with_megablocks', 'attention_classes', 'attention_implementations', 'fcs',
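
Note on test isolation (illustration only, not part of this diff): the save_registry context manager added in registry_utils.py backs the new autouse fixture in tests/fixtures/autouse.py. A minimal sketch of how a temporary registration is scoped; build_noop_ffn and 'noop_ffn' are hypothetical names, and the final check assumes registry membership reads the shared catalogue state that save_registry() snapshots and restores.

import torch.nn as nn

from llmfoundry.layers_registry import ffns
from llmfoundry.utils.registry_utils import save_registry


def build_noop_ffn(**kwargs) -> nn.Module:
    return nn.Identity()


with save_registry():
    # Temporary registration, e.g. inside a single test.
    ffns.register('noop_ffn', func=build_noop_ffn)
    assert 'noop_ffn' in ffns

# On exit the saved catalogue state is restored, so the temporary entry is gone.
assert 'noop_ffn' not in ffns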