From 1e0b415abbddc31cb5aab25ef8de2e98133a8e1b Mon Sep 17 00:00:00 2001 From: Daniel King Date: Wed, 3 Apr 2024 01:56:10 -0700 Subject: [PATCH 01/32] ffn registry --- llmfoundry/__init__.py | 4 +- llmfoundry/layers_registry.py | 18 +++- llmfoundry/models/layers/__init__.py | 4 +- llmfoundry/models/layers/blocks.py | 23 ++--- llmfoundry/models/layers/ffn.py | 100 ++++++++++++--------- llmfoundry/models/layers/layer_builders.py | 41 ++++++++- llmfoundry/models/mpt/configuration_mpt.py | 3 +- llmfoundry/models/mpt/modeling_mpt.py | 2 +- llmfoundry/models/utils/act_ckpt.py | 14 +-- llmfoundry/utils/registry_utils.py | 2 +- 10 files changed, 142 insertions(+), 69 deletions(-) diff --git a/llmfoundry/__init__.py b/llmfoundry/__init__.py index 922f738e9a..54a55d6e97 100644 --- a/llmfoundry/__init__.py +++ b/llmfoundry/__init__.py @@ -28,7 +28,7 @@ MultiheadAttention, attn_bias_shape, build_alibi_bias, build_attn_bias, flash_attn_fn, scaled_multihead_dot_product_attention) from llmfoundry.models.layers.blocks import MPTBlock -from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY, MPTMLP, build_ffn +from llmfoundry.models.layers.ffn import MPTMLP from llmfoundry.models.mpt import (ComposerMPTCausalLM, MPTConfig, MPTForCausalLM, MPTModel, MPTPreTrainedModel) from llmfoundry.tokenizers import TiktokenTokenizerWrapper @@ -37,9 +37,7 @@ 'build_finetuning_dataloader', 'Seq2SeqFinetuningCollator', 'MPTBlock', - 'FFN_CLASS_REGISTRY', 'MPTMLP', - 'build_ffn', 'MPTConfig', 'MPTPreTrainedModel', 'MPTModel', diff --git a/llmfoundry/layers_registry.py b/llmfoundry/layers_registry.py index 9c7dabe128..d7aa46767c 100644 --- a/llmfoundry/layers_registry.py +++ b/llmfoundry/layers_registry.py @@ -1,7 +1,7 @@ # Copyright 2024 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 -from typing import Type +from typing import Callable, Type import torch @@ -15,6 +15,22 @@ entry_points=True, description=_norm_description) +_ffns_description = """The ffns registry is used to register functions that build ffn layers.""" +ffns = create_registry('llmfoundry', + 'ffns', + generic_type=Callable, + entry_points=True, + description=_ffns_description) + +_ffns_with_norm_description = """The ffns_with_norm registry is used to register functions that build ffn layers that apply a normalization layer.""" +ffns_with_norm = create_registry('llmfoundry', + 'ffns_with_norm', + generic_type=Callable, + entry_points=True, + description=_ffns_with_norm_description) + __all__ = [ 'norms', + 'ffns', + 'ffns_with_norm', ] diff --git a/llmfoundry/models/layers/__init__.py b/llmfoundry/models/layers/__init__.py index 262f190b47..a76c28feba 100644 --- a/llmfoundry/models/layers/__init__.py +++ b/llmfoundry/models/layers/__init__.py @@ -8,7 +8,7 @@ from llmfoundry.models.layers.blocks import MPTBlock from llmfoundry.models.layers.custom_embedding import SharedEmbedding from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY -from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY, MPTMLP, build_ffn +from llmfoundry.models.layers.ffn import MPTMLP from llmfoundry.models.layers.norm import LPLayerNorm __all__ = [ @@ -26,6 +26,4 @@ 'LPLayerNorm', 'FC_CLASS_REGISTRY', 'SharedEmbedding', - 'FFN_CLASS_REGISTRY', - 'build_ffn', ] diff --git a/llmfoundry/models/layers/blocks.py b/llmfoundry/models/layers/blocks.py index 42feb983d4..3024699181 100644 --- a/llmfoundry/models/layers/blocks.py +++ b/llmfoundry/models/layers/blocks.py @@ -9,8 +9,7 @@ import torch.nn as nn from llmfoundry.models.layers.attention import 
ATTN_CLASS_REGISTRY -from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY, build_ffn -from llmfoundry.models.layers.layer_builders import build_norm +from llmfoundry.models.layers.layer_builders import build_ffn, build_norm try: from flash_attn.bert_padding import unpad_input, pad_input # type: ignore # yapf: disable # isort: skip @@ -100,21 +99,25 @@ def __init__( **attn_config_subset_for_attn_class, bias=not no_bias, ) - self.norm_2 = None - if not getattr(FFN_CLASS_REGISTRY[ffn_config['ffn_type']], '_has_norm', - False): - self.norm_2 = build_norm( - name=norm_type.lower(), - normalized_shape=d_model, - device=device, - ) + + ffn_type = ffn_config.pop('ffn_type') self.ffn = build_ffn( + name=ffn_type, d_model=d_model, expansion_ratio=expansion_ratio, device=device, bias=not no_bias, **ffn_config, ) + + self.norm_2 = None + if not getattr(self.ffn, '_has_norm', False): + self.norm_2 = build_norm( + name=norm_type.lower(), + normalized_shape=d_model, + device=device, + ) + self.resid_attn_dropout = nn.Dropout(resid_pdrop) self.resid_ffn_dropout = nn.Dropout(resid_pdrop) diff --git a/llmfoundry/models/layers/ffn.py b/llmfoundry/models/layers/ffn.py index 9389cf385f..6932c93fda 100644 --- a/llmfoundry/models/layers/ffn.py +++ b/llmfoundry/models/layers/ffn.py @@ -11,6 +11,7 @@ import torch import torch.nn as nn +from llmfoundry.layers_registry import ffns, ffns_with_norm from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY try: @@ -149,17 +150,47 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x)) -FFN_CLASS_REGISTRY = { - 'mptmlp': MPTMLP, - 'mptglu': MPTGLU, -} - -if te is not None: - te.LayerNormMLP._has_norm = True - FFN_CLASS_REGISTRY['te_ln_mlp'] = te.LayerNormMLP - - -def build_ffn( +def build_mptglu( + d_model: int, + expansion_ratio: Union[int, float], + fc_type: str = 'torch', + ffn_hidden_size: Optional[int] = None, + ffn_act_fn: Optional[dict] = None, + device: Optional[str] = None, + bias: bool = True, +) -> nn.Module: + return MPTGLU( + d_model=d_model, + expansion_ratio=expansion_ratio, + fc_type=fc_type, + act_fn=resolve_ffn_act_fn(ffn_act_fn), + ffn_hidden_size=ffn_hidden_size, + device=device, + bias=bias, + ) + + +def build_mptmlp( + d_model: int, + expansion_ratio: Union[int, float], + fc_type: str = 'torch', + ffn_hidden_size: Optional[int] = None, + ffn_act_fn: Optional[dict] = None, + device: Optional[str] = None, + bias: bool = True, +) -> nn.Module: + return MPTMLP( + d_model=d_model, + expansion_ratio=expansion_ratio, + fc_type=fc_type, + act_fn=resolve_ffn_act_fn(ffn_act_fn), + ffn_hidden_size=ffn_hidden_size, + device=device, + bias=bias, + ) + + +def build_te_ln_mlp( d_model: int, expansion_ratio: Union[int, float], fc_type: str = 'torch', @@ -169,34 +200,23 @@ def build_ffn( bias: bool = True, **kwargs: Any, ) -> nn.Module: - ffn_type = kwargs.pop('ffn_type') - if ffn_type in ['mptmlp', 'mptglu']: - if len(kwargs) > 0: - raise ValueError( - f'MPTMLP (or MPTGLU) got an unexpected keyword argument: {kwargs}' - ) - return FFN_CLASS_REGISTRY[ffn_type]( - d_model=d_model, - expansion_ratio=expansion_ratio, - fc_type=fc_type, - act_fn=resolve_ffn_act_fn(ffn_act_fn), - ffn_hidden_size=ffn_hidden_size, - device=device, - bias=bias, - ) - elif ffn_type == 'te_ln_mlp': - assert te is not None - ffn_hidden_size = resolve_ffn_hidden_size(d_model, expansion_ratio, - ffn_hidden_size) - if ffn_act_fn is not None: - raise ValueError( - f'Transformer Engine block does not support custom 
activation functions.' - ) - return te.LayerNormMLP( - hidden_size=d_model, - ffn_hidden_size=ffn_hidden_size, - bias=bias, - **kwargs, + assert te is not None + ffn_hidden_size = resolve_ffn_hidden_size(d_model, expansion_ratio, + ffn_hidden_size) + if ffn_act_fn is not None: + raise ValueError( + f'Transformer Engine block does not support custom activation functions.' ) + return te.LayerNormMLP( + hidden_size=d_model, + ffn_hidden_size=ffn_hidden_size, + bias=bias, + **kwargs, + ) + - raise ValueError(f'{ffn_type=} not recognized.') +ffns.register('mptglu', func=build_mptglu) +ffns.register('mptmlp', func=build_mptmlp) + +if te is not None: + ffns_with_norm.register('te_ln_mlp', func=build_te_ln_mlp) diff --git a/llmfoundry/models/layers/layer_builders.py b/llmfoundry/models/layers/layer_builders.py index 23f5b89668..1f845d1dd4 100644 --- a/llmfoundry/models/layers/layer_builders.py +++ b/llmfoundry/models/layers/layer_builders.py @@ -1,11 +1,11 @@ # Copyright 2024 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 -from typing import List, Optional, Union +from typing import Any, Dict, List, Optional, Union import torch -from llmfoundry.layers_registry import norms +from llmfoundry.layers_registry import ffns, ffns_with_norm, norms from llmfoundry.utils.registry_utils import construct_from_registry @@ -23,3 +23,40 @@ def build_norm( registry=norms, pre_validation_function=torch.nn.Module, kwargs=kwargs) + + +def build_ffn( + name: str, + d_model: int, + expansion_ratio: float, + device: Optional[str], + bias: bool, + ffn_kwargs: Dict[str, Any], +): + registry_to_use = ffns + if name in ffns_with_norm: + registry_to_use = ffns_with_norm + + kwargs = { + 'd_model': d_model, + 'expansion_ratio': expansion_ratio, + 'device': device, + 'bias': bias, + **ffn_kwargs, + } + + def _validation_function(maybe_module: Any): + if not isinstance(maybe_module, torch.nn.Module): + raise ValueError(f'Function {name} must return a torch.nn.Module.') + + result = construct_from_registry( + name=name, + registry=registry_to_use, + post_validation_function=_validation_function, + partial_function=False, + kwargs=kwargs) + + if name in ffns_with_norm: + result._has_norm = True + + return result diff --git a/llmfoundry/models/mpt/configuration_mpt.py b/llmfoundry/models/mpt/configuration_mpt.py index 2f58ea312e..b83b89d68b 100644 --- a/llmfoundry/models/mpt/configuration_mpt.py +++ b/llmfoundry/models/mpt/configuration_mpt.py @@ -18,8 +18,7 @@ # isort: off from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY # type: ignore (see note) from llmfoundry.models.layers.norm import LPLayerNorm # type: ignore (see note) -from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY # type: ignore (see note) -from llmfoundry.models.layers.layer_builders import build_norm # type: ignore (see note) +from llmfoundry.models.layers.layer_builders import build_norm, build_ffn # type: ignore (see note) from llmfoundry.layers_registry import norms # type: ignore (see note) from llmfoundry.utils.registry_utils import construct_from_registry # type: ignore (see note) diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index d54b797269..8dbef3a99e 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -46,7 +46,7 @@ build_attn_bias, gen_slopes) from llmfoundry.models.layers.blocks import MPTBlock from llmfoundry.models.layers.custom_embedding import SharedEmbedding -from llmfoundry.models.layers.ffn import build_ffn as build_ffn 
+from llmfoundry.models.layers.layer_builders import build_ffn as build_ffn from llmfoundry.models.layers.layer_builders import build_norm from llmfoundry.models.mpt.configuration_mpt import MPTConfig diff --git a/llmfoundry/models/utils/act_ckpt.py b/llmfoundry/models/utils/act_ckpt.py index bde7c92bd7..23e8c09b17 100644 --- a/llmfoundry/models/utils/act_ckpt.py +++ b/llmfoundry/models/utils/act_ckpt.py @@ -5,10 +5,9 @@ import torch -from llmfoundry.layers_registry import norms +from llmfoundry.layers_registry import ffns, ffns_with_norm, norms from llmfoundry.models.layers.attention import ATTN_CLASS_REGISTRY from llmfoundry.models.layers.blocks import MPTBlock -from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY def pass_on_block_idx(parent: torch.nn.Module): @@ -27,14 +26,17 @@ def get_act_ckpt_module(mod_name: str) -> Any: mod_type = MPTBlock elif mod_name in ATTN_CLASS_REGISTRY: mod_type = ATTN_CLASS_REGISTRY[mod_name] - elif mod_name in FFN_CLASS_REGISTRY: - mod_type = FFN_CLASS_REGISTRY[mod_name] + elif mod_name in ffns: + mod_type = ffns.get(mod_name) + elif mod_name in ffns_with_norm: + mod_type = ffns_with_norm.get(mod_name) elif mod_name in norms: mod_type = norms.get(mod_name) else: msg = ', '.join( - list(ATTN_CLASS_REGISTRY.keys()) + list(FFN_CLASS_REGISTRY.keys()) + - list(norms.get_all()) + ['MPTBlock']) + list(ATTN_CLASS_REGISTRY.keys()) + list(ffns.get_all()) + + list(ffns_with_norm.get_all()) + list(norms.get_all()) + + ['MPTBlock']) raise ValueError( f'{mod_name} (specified in activation_checkpointing_target) is not a recognized option out of available options {msg}.' ) diff --git a/llmfoundry/utils/registry_utils.py b/llmfoundry/utils/registry_utils.py index 0901ea198a..2693b3fc25 100644 --- a/llmfoundry/utils/registry_utils.py +++ b/llmfoundry/utils/registry_utils.py @@ -142,7 +142,7 @@ def construct_from_registry( ) if post_validation_function is not None: - post_validation_function(registered_constructor) + post_validation_function(constructed_item) return constructed_item From f327dc3e6bf5995fc337be21af9938bb367717d0 Mon Sep 17 00:00:00 2001 From: Daniel King Date: Wed, 3 Apr 2024 20:27:28 -0700 Subject: [PATCH 02/32] fix --- llmfoundry/models/layers/blocks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llmfoundry/models/layers/blocks.py b/llmfoundry/models/layers/blocks.py index 3024699181..4818116c18 100644 --- a/llmfoundry/models/layers/blocks.py +++ b/llmfoundry/models/layers/blocks.py @@ -107,7 +107,7 @@ def __init__( expansion_ratio=expansion_ratio, device=device, bias=not no_bias, - **ffn_config, + ffn_kwargs=ffn_config, ) self.norm_2 = None From 01e1d7ee76207651a178695de57efaebc9c3b5cb Mon Sep 17 00:00:00 2001 From: Daniel King Date: Thu, 4 Apr 2024 18:56:41 -0700 Subject: [PATCH 03/32] clean up --- llmfoundry/layers_registry.py | 8 ++++++-- llmfoundry/registry.py | 4 +++- tests/test_registry.py | 2 ++ 3 files changed, 11 insertions(+), 3 deletions(-) diff --git a/llmfoundry/layers_registry.py b/llmfoundry/layers_registry.py index d7aa46767c..4d529058db 100644 --- a/llmfoundry/layers_registry.py +++ b/llmfoundry/layers_registry.py @@ -15,14 +15,18 @@ entry_points=True, description=_norm_description) -_ffns_description = """The ffns registry is used to register functions that build ffn layers.""" +_ffns_description = ( + 'The ffns registry is used to register functions that build ffn layers.' 
+ + 'See ffn.py for examples.') ffns = create_registry('llmfoundry', 'ffns', generic_type=Callable, entry_points=True, description=_ffns_description) -_ffns_with_norm_description = """The ffns_with_norm registry is used to register functions that build ffn layers that apply a normalization layer.""" +_ffns_with_norm_description = ( + 'The ffns_with_norm registry is used to register functions that build ffn layers that apply a normalization layer.' + + 'See ffn.py for examples.') ffns_with_norm = create_registry('llmfoundry', 'ffns_with_norm', generic_type=Callable, diff --git a/llmfoundry/registry.py b/llmfoundry/registry.py index 424075da3b..acde76a630 100644 --- a/llmfoundry/registry.py +++ b/llmfoundry/registry.py @@ -12,7 +12,7 @@ from transformers import PreTrainedTokenizerBase from llmfoundry.interfaces import CallbackWithConfig -from llmfoundry.layers_registry import norms +from llmfoundry.layers_registry import ffns, ffns_with_norm, norms from llmfoundry.utils.registry_utils import create_registry _loggers_description = ( @@ -121,4 +121,6 @@ 'metrics', 'dataloaders', 'norms', + 'ffns', + 'ffns_with_norm', ] diff --git a/tests/test_registry.py b/tests/test_registry.py index c93c7c9749..aee3aaf189 100644 --- a/tests/test_registry.py +++ b/tests/test_registry.py @@ -31,6 +31,8 @@ def test_expected_registries_exist(): 'metrics', 'models', 'norms', + 'ffns', + 'ffns_with_norm', } assert existing_registries == expected_registry_names From 80d67fa8a97951fa0d6486ea305cceed2bd8e615 Mon Sep 17 00:00:00 2001 From: Daniel King Date: Thu, 11 Apr 2024 16:28:22 -0700 Subject: [PATCH 04/32] pc --- llmfoundry/models/layers/ffn.py | 56 ++++++++++++++++++--------------- 1 file changed, 31 insertions(+), 25 deletions(-) diff --git a/llmfoundry/models/layers/ffn.py b/llmfoundry/models/layers/ffn.py index e0681d80a7..45f293f6e3 100644 --- a/llmfoundry/models/layers/ffn.py +++ b/llmfoundry/models/layers/ffn.py @@ -10,6 +10,7 @@ import torch import torch.nn as nn +from torch.distributed import ProcessGroup from torch.distributed._tensor import DeviceMesh, DTensor, Placement, Shard from llmfoundry.layers_registry import ffns, ffns_with_norm @@ -234,6 +235,7 @@ def build_te_ln_mlp( **kwargs, ) + def build_torch_dmoe( d_model: int, expansion_ratio: Union[int, float], @@ -254,19 +256,20 @@ def build_torch_dmoe( raise ValueError(f'Invalid arguments to torch dmoe: {kwargs}.') return dMoE( - hidden_size=d_model, - ffn_hidden_size=resolve_ffn_hidden_size(d_model, expansion_ratio, - ffn_hidden_size), - moe_num_experts=moe_num_experts, - moe_top_k=moe_top_k, - mlp_type=mlp_type, - bias=bias, - moe_jitter_eps=moe_jitter_eps, - activation_fn=resolve_ffn_act_fn(ffn_act_fn), - moe_normalize_expert_weights=moe_normalize_expert_weights, - uniform_expert_assignment=uniform_expert_assignment, - device=device, # pyright: ignore[reportGeneralTypeIssues] - ) + hidden_size=d_model, + ffn_hidden_size=resolve_ffn_hidden_size(d_model, expansion_ratio, + ffn_hidden_size), + moe_num_experts=moe_num_experts, + moe_top_k=moe_top_k, + mlp_type=mlp_type, + bias=bias, + moe_jitter_eps=moe_jitter_eps, + activation_fn=resolve_ffn_act_fn(ffn_act_fn), + moe_normalize_expert_weights=moe_normalize_expert_weights, + uniform_expert_assignment=uniform_expert_assignment, + device=device, # pyright: ignore[reportGeneralTypeIssues] + ) + def _mb_setup_args( d_model: int, @@ -276,7 +279,7 @@ def _mb_setup_args( device: Optional[str], bias: bool, kwargs: dict[str, Any], -) -> tuple['megablocks.layers.arguments.Arguments', int, 
torch.distributed.ProcessGroup]: +) -> tuple['megablocks.layers.arguments.Arguments', int, ProcessGroup]: if megablocks is None: raise RuntimeError( 'Requirements for megablocks not installed; see install instructions in `README.md`.' @@ -287,7 +290,7 @@ def _mb_setup_args( args.device = device ffn_hidden_size = resolve_ffn_hidden_size(d_model, expansion_ratio, - ffn_hidden_size) + ffn_hidden_size) args.ffn_hidden_size = ffn_hidden_size if ffn_act_fn is not None: @@ -299,15 +302,15 @@ def _mb_setup_args( moe_world_size = expert_parallel_group.size() if kwargs.get('moe_world_size') != moe_world_size: raise RuntimeError( - f'MoE expert_parallel_group configured with incorrect world size.' - ) - + f'MoE expert_parallel_group configured with incorrect world size.') + return args, moe_world_size, expert_parallel_group + def _patch_ffn_mb( ffn: nn.Module, moe_world_size: int, - expert_parallel_group: torch.distributed.ProcessGroup, + expert_parallel_group: ProcessGroup, device_mesh: DeviceMesh, args: 'megablocks.layers.arguments.Arguments', **kwargs: Any, @@ -325,9 +328,9 @@ def _patch_ffn_mb( # Register in two loops as you cannot overwrite parameters while iterating over named_parameters() dtensorified_params = [ (name, - dtensorify_param(param=parameter, - mesh=expert_mesh, - placements=expert_placements)) + dtensorify_param(param=parameter, + mesh=expert_mesh, + placements=expert_placements)) for name, parameter in ffn.experts.mlp.named_parameters() ] for name, dtensorified_param in dtensorified_params: @@ -339,15 +342,13 @@ def _patch_ffn_mb( elif device_mesh.mesh.ndim == 3: raise RuntimeError(f'HSDP + MoE is not supported.') else: - raise ValueError( - f'{device_mesh.mesh.ndim=} not supported for MoE.') + raise ValueError(f'{device_mesh.mesh.ndim=} not supported for MoE.') ffn.experts._fsdp_kwargs_dict = { 'device_mesh': submesh, } - def build_mb_moe( d_model: int, expansion_ratio: Union[int, float], @@ -387,6 +388,8 @@ def build_mb_moe( **kwargs, ) + return ffn + def build_mb_dmoe( d_model: int, @@ -430,6 +433,9 @@ def build_mb_dmoe( **kwargs, ) + return ffn + + ffns.register('mptglu', func=build_mptglu) ffns.register('mptmlp', func=build_mptmlp) ffns.register('torch_dmoe', func=build_torch_dmoe) From f7e4fec73a5bd20b6f045da8681b2df3412dc944 Mon Sep 17 00:00:00 2001 From: Daniel King Date: Thu, 11 Apr 2024 16:39:57 -0700 Subject: [PATCH 05/32] clean up --- llmfoundry/layers_registry.py | 11 +++++++++++ llmfoundry/models/layers/ffn.py | 7 ++++--- llmfoundry/models/layers/layer_builders.py | 9 ++++++++- llmfoundry/models/mpt/configuration_mpt.py | 3 ++- llmfoundry/models/mpt/modeling_mpt.py | 7 ++++--- llmfoundry/models/utils/config_moe_args.py | 3 ++- llmfoundry/models/utils/mpt_param_count.py | 10 ++++++---- llmfoundry/registry.py | 4 +++- llmfoundry/utils/config_utils.py | 3 ++- scripts/train/train.py | 3 ++- 10 files changed, 44 insertions(+), 16 deletions(-) diff --git a/llmfoundry/layers_registry.py b/llmfoundry/layers_registry.py index 4d529058db..c2c1c8930a 100644 --- a/llmfoundry/layers_registry.py +++ b/llmfoundry/layers_registry.py @@ -33,8 +33,19 @@ entry_points=True, description=_ffns_with_norm_description) +_ffns_with_megablocks_description = ( + 'The ffns_with_megablocks registry is used to register functions that build ffn layers using MegaBlocks.' 
+ + 'See ffn.py for examples.') +ffns_with_megablocks = create_registry( + 'llmfoundry', + 'ffns_with_megablocks', + generic_type=Callable, + entry_points=True, + description=_ffns_with_megablocks_description) + __all__ = [ 'norms', 'ffns', 'ffns_with_norm', + 'ffns_with_megablocks', ] diff --git a/llmfoundry/models/layers/ffn.py b/llmfoundry/models/layers/ffn.py index 45f293f6e3..b6e7f562fb 100644 --- a/llmfoundry/models/layers/ffn.py +++ b/llmfoundry/models/layers/ffn.py @@ -13,7 +13,8 @@ from torch.distributed import ProcessGroup from torch.distributed._tensor import DeviceMesh, DTensor, Placement, Shard -from llmfoundry.layers_registry import ffns, ffns_with_norm +from llmfoundry.layers_registry import (ffns, ffns_with_megablocks, + ffns_with_norm) from llmfoundry.models.layers.dmoe import dMoE from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY @@ -444,5 +445,5 @@ def build_mb_dmoe( ffns_with_norm.register('te_ln_mlp', func=build_te_ln_mlp) if is_megablocks_imported: - ffns.register('mb_moe', func=build_mb_moe) - ffns.register('mb_dmoe', func=build_mb_dmoe) + ffns_with_megablocks.register('mb_moe', func=build_mb_moe) + ffns_with_megablocks.register('mb_dmoe', func=build_mb_dmoe) diff --git a/llmfoundry/models/layers/layer_builders.py b/llmfoundry/models/layers/layer_builders.py index 1f845d1dd4..dc6474f06d 100644 --- a/llmfoundry/models/layers/layer_builders.py +++ b/llmfoundry/models/layers/layer_builders.py @@ -5,7 +5,8 @@ import torch -from llmfoundry.layers_registry import ffns, ffns_with_norm, norms +from llmfoundry.layers_registry import (ffns, ffns_with_megablocks, + ffns_with_norm, norms) from llmfoundry.utils.registry_utils import construct_from_registry @@ -37,6 +38,9 @@ def build_ffn( if name in ffns_with_norm: registry_to_use = ffns_with_norm + if name in ffns_with_megablocks: + registry_to_use = ffns_with_megablocks + kwargs = { 'd_model': d_model, 'expansion_ratio': expansion_ratio, @@ -59,4 +63,7 @@ def _validation_function(maybe_module: Any): if name in ffns_with_norm: result._has_norm = True + if name in ffns_with_megablocks: + result._uses_megablocks = True + return result diff --git a/llmfoundry/models/mpt/configuration_mpt.py b/llmfoundry/models/mpt/configuration_mpt.py index b6c3fa5e98..9437aaf07d 100644 --- a/llmfoundry/models/mpt/configuration_mpt.py +++ b/llmfoundry/models/mpt/configuration_mpt.py @@ -8,6 +8,7 @@ from transformers import PretrainedConfig +from llmfoundry.layers_registry import ffns_with_megablocks from llmfoundry.models.layers.attention import (check_alibi_support, is_flash_v2_installed) from llmfoundry.models.layers.blocks import attn_config_defaults @@ -290,7 +291,7 @@ def _validate_config(self) -> None: ) elif self.ffn_config['ffn_type'] in ['mptmlp', 'mptglu']: self.ffn_config['fc_type'] = self.fc_type - elif self.ffn_config['ffn_type'] in ['mb_moe', 'mb_dmoe']: + elif self.ffn_config['ffn_type'] in ffns_with_megablocks: self.ffn_config['return_bias'] = False elif self.ffn_config['ffn_type'] == 'te_ln_mlp': self.ffn_config['bias'] = not self.no_bias diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index fd5567e04d..aca2350051 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -20,6 +20,7 @@ from composer.models import HuggingFaceModel from composer.utils import dist +from llmfoundry.layers_registry import ffns_with_megablocks from llmfoundry.models.layers.attention import is_flash_v2_installed if is_flash_v2_installed(): @@ -324,7 +325,7 @@ def 
__init__(self, config: MPTConfig): self.emb_drop = nn.Dropout(config.emb_pdrop) self.mb_args = None block_args = config.to_dict() - if block_args['ffn_config']['ffn_type'] in ('mb_moe', 'mb_dmoe'): + if block_args['ffn_config']['ffn_type'] in ffns_with_megablocks: block_args['ffn_config'] = config_moe_args( block_args['ffn_config'], config.d_model, @@ -1026,7 +1027,7 @@ def get_targets(self, batch: Mapping) -> torch.Tensor: return targets def forward(self, batch: MutableMapping) -> CausalLMOutputWithPast: - if self.config.ffn_config['ffn_type'] in ('mb_moe', 'mb_dmoe'): + if self.config.ffn_config['ffn_type'] in ffns_with_megablocks: # Clear MegaBlocks MoE load balancing loss cache try: # Add try/catch to avoid transformers complaining and raising errors from megablocks.layers.moe import clear_load_balancing_loss @@ -1053,7 +1054,7 @@ def loss(self, outputs: CausalLMOutputWithPast, else: loss = losses.sum() / (targets != self.loss_fn.ignore_index).sum() - if self.config.ffn_config['ffn_type'] in ('mb_moe', 'mb_dmoe'): + if self.config.ffn_config['ffn_type'] in ffns_with_megablocks: # MegaBlocks MoE load balancing loss try: # Add try/catch to avoid transformers complaining and raising errors from megablocks.layers.moe import batched_load_balancing_loss diff --git a/llmfoundry/models/utils/config_moe_args.py b/llmfoundry/models/utils/config_moe_args.py index b69cd18348..4bbb246613 100644 --- a/llmfoundry/models/utils/config_moe_args.py +++ b/llmfoundry/models/utils/config_moe_args.py @@ -9,6 +9,7 @@ from packaging import version from torch import distributed +from llmfoundry.layers_registry import ffns_with_megablocks from llmfoundry.models.layers.ffn import resolve_ffn_hidden_size @@ -156,7 +157,7 @@ def config_moe_args( Returns: ffn_config (dict): FFN configuration with MoE configured. """ - if ffn_config['ffn_type'] in ('mb_moe', 'mb_dmoe'): + if ffn_config['ffn_type'] in ffns_with_megablocks: return config_megablocks_moe_args( ffn_config=ffn_config, d_model=d_model, diff --git a/llmfoundry/models/utils/mpt_param_count.py b/llmfoundry/models/utils/mpt_param_count.py index d90929713b..cb1a5c0935 100644 --- a/llmfoundry/models/utils/mpt_param_count.py +++ b/llmfoundry/models/utils/mpt_param_count.py @@ -16,6 +16,8 @@ from torch import Tensor, nn from torch.distributed._tensor import DTensor +from llmfoundry.layers_registry import ffns_with_megablocks + def module_n_params(module: nn.Module) -> int: """Gets the number of parameters in this module excluding child modules. @@ -127,7 +129,7 @@ def megablocks_n_active_params(mpt_model) -> int: # type: ignore def mpt_get_total_params(mpt_model) -> int: # type: ignore - """Calculates the total paramter count of an MPT model. + """Calculates the total parameter count of an MPT model. Note: Must be called before model parameters are sharded by FSDP. @@ -138,14 +140,14 @@ def mpt_get_total_params(mpt_model) -> int: # type: ignore Returns: An int for the total number of parameters in this MPT model. """ - if mpt_model.config.ffn_config['ffn_type'] in ('mb_moe', 'mb_dmoe'): + if mpt_model.config.ffn_config['ffn_type'] in ffns_with_megablocks: return megablocks_n_total_params(mpt_model) else: return sum(p.numel() for p in mpt_model.parameters()) def mpt_get_active_params(mpt_model) -> int: # type: ignore - """Calculates the total paramter count of an MPT model. + """Calculates the total parameter count of an MPT model. Note: Must be called before model parameters are sharded by FSDP. 
@@ -156,7 +158,7 @@ def mpt_get_active_params(mpt_model) -> int: # type: ignore Returns: An int for the active number of parameters in this MPT model. """ - if mpt_model.config.ffn_config['ffn_type'] in ('mb_moe', 'mb_dmoe'): + if mpt_model.config.ffn_config['ffn_type'] in ffns_with_megablocks: params = megablocks_n_active_params(mpt_model) else: params = sum(p.numel() for p in mpt_model.parameters()) diff --git a/llmfoundry/registry.py b/llmfoundry/registry.py index acde76a630..60276bd5ea 100644 --- a/llmfoundry/registry.py +++ b/llmfoundry/registry.py @@ -12,7 +12,8 @@ from transformers import PreTrainedTokenizerBase from llmfoundry.interfaces import CallbackWithConfig -from llmfoundry.layers_registry import ffns, ffns_with_norm, norms +from llmfoundry.layers_registry import (ffns, ffns_with_megablocks, + ffns_with_norm, norms) from llmfoundry.utils.registry_utils import create_registry _loggers_description = ( @@ -123,4 +124,5 @@ 'norms', 'ffns', 'ffns_with_norm', + 'ffns_with_megablocks', ] diff --git a/llmfoundry/utils/config_utils.py b/llmfoundry/utils/config_utils.py index d2c3b733c0..a4fd005c3a 100644 --- a/llmfoundry/utils/config_utils.py +++ b/llmfoundry/utils/config_utils.py @@ -11,6 +11,7 @@ from omegaconf import DictConfig, ListConfig from omegaconf import OmegaConf as om +from llmfoundry.layers_registry import ffns_with_megablocks from llmfoundry.models.utils import init_empty_weights log = logging.getLogger(__name__) @@ -131,7 +132,7 @@ def process_init_device(model_cfg: DictConfig, fsdp_config: Optional[Dict]): # Set ffn_config.device_mesh to fsdp_config.device_mesh if fsdp_config is not None and 'device_mesh' in fsdp_config and 'ffn_config' in model_cfg and model_cfg[ - 'ffn_config'].get('ffn_type', None) in {'mb_moe', 'mb_dmoe'}: + 'ffn_config'].get('ffn_type', None) in ffns_with_megablocks: # Raise ValueError if not using device mesh with MoE expert parallelism if fsdp_config['device_mesh'] is None and model_cfg['ffn_config'].get( 'moe_world_size', 1) > 1: diff --git a/scripts/train/train.py b/scripts/train/train.py index 96066d5a1d..5bcce0038e 100644 --- a/scripts/train/train.py +++ b/scripts/train/train.py @@ -26,6 +26,7 @@ install() from llmfoundry.callbacks import AsyncEval from llmfoundry.data.dataloader import build_dataloader +from llmfoundry.layers_registry import ffns_with_megablocks from llmfoundry.utils.builders import (add_metrics_to_eval_loaders, build_algorithm, build_callback, build_composer_model, build_evaluators, @@ -102,7 +103,7 @@ def validate_config(cfg: DictConfig): ) if cfg.model.get('ffn_config', {}).get('ffn_type', - 'mptmlp') in ('mb_moe', 'mb_dmoe'): + 'mptmlp') in ffns_with_megablocks: moe_world_size = cfg.model.get('ffn_config', {}).get('moe_world_size', 1) use_orig_params = cfg.get('fsdp_config', From 6ebd28c6214d9fabeaf710cf877c0967f9f55ccf Mon Sep 17 00:00:00 2001 From: Daniel King Date: Thu, 11 Apr 2024 18:42:12 -0700 Subject: [PATCH 06/32] fix --- llmfoundry/models/layers/blocks.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/llmfoundry/models/layers/blocks.py b/llmfoundry/models/layers/blocks.py index cf7a67f10d..d1c22f7f3a 100644 --- a/llmfoundry/models/layers/blocks.py +++ b/llmfoundry/models/layers/blocks.py @@ -72,6 +72,14 @@ def __init__( del kwargs # unused, just to capture any extra args from the config super().__init__() + self.ffn = build_ffn( + d_model=d_model, + expansion_ratio=expansion_ratio, + device=device, + bias=not no_bias, + **ffn_config, + ) + if self.fuse_norm_attn_norm: 
self.norm_attn_norm = FusedNormAttentionNorm( d_model=d_model, @@ -122,13 +130,6 @@ def __init__( device=device, ) - self.ffn = build_ffn( - d_model=d_model, - expansion_ratio=expansion_ratio, - device=device, - bias=not no_bias, - **ffn_config, - ) self.resid_attn_dropout = nn.Dropout(resid_pdrop) self.resid_ffn_dropout = nn.Dropout(resid_pdrop) self.use_pad_tok_in_ffn = use_pad_tok_in_ffn From 3ecd51e27e0ffed498336aab807c6218188002a9 Mon Sep 17 00:00:00 2001 From: Daniel King Date: Thu, 11 Apr 2024 18:45:47 -0700 Subject: [PATCH 07/32] fix --- llmfoundry/models/layers/blocks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llmfoundry/models/layers/blocks.py b/llmfoundry/models/layers/blocks.py index d1c22f7f3a..1eb57e2055 100644 --- a/llmfoundry/models/layers/blocks.py +++ b/llmfoundry/models/layers/blocks.py @@ -77,7 +77,7 @@ def __init__( expansion_ratio=expansion_ratio, device=device, bias=not no_bias, - **ffn_config, + ffn_kwargs=ffn_config, ) if self.fuse_norm_attn_norm: From ce04066c7a4f2195291157c12357dc8f87f6ffb5 Mon Sep 17 00:00:00 2001 From: Daniel King Date: Thu, 11 Apr 2024 18:49:56 -0700 Subject: [PATCH 08/32] fix --- llmfoundry/models/layers/blocks.py | 1 + 1 file changed, 1 insertion(+) diff --git a/llmfoundry/models/layers/blocks.py b/llmfoundry/models/layers/blocks.py index 1eb57e2055..d1ed06b9d6 100644 --- a/llmfoundry/models/layers/blocks.py +++ b/llmfoundry/models/layers/blocks.py @@ -73,6 +73,7 @@ def __init__( super().__init__() self.ffn = build_ffn( + name=ffn_config['ffn_type'], d_model=d_model, expansion_ratio=expansion_ratio, device=device, From 539fa0afe58fa6c2907c29ace02bd9020095a3aa Mon Sep 17 00:00:00 2001 From: Daniel King Date: Thu, 11 Apr 2024 19:00:27 -0700 Subject: [PATCH 09/32] fix --- llmfoundry/models/layers/ffn.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/llmfoundry/models/layers/ffn.py b/llmfoundry/models/layers/ffn.py index 9a35e3dad9..a8fbaef30a 100644 --- a/llmfoundry/models/layers/ffn.py +++ b/llmfoundry/models/layers/ffn.py @@ -317,7 +317,6 @@ def _patch_ffn_mb( expert_parallel_group: ProcessGroup, device_mesh: DeviceMesh, args: 'megablocks.layers.arguments.Arguments', - **kwargs: Any, ): # Attach args to MLP directly for use in param_init_fn ffn.experts.mlp.hidden_size = args.ffn_hidden_size @@ -325,8 +324,6 @@ def _patch_ffn_mb( ffn.experts.mlp.weight_parallel_group = args.weight_parallel_group if moe_world_size > 1: - device_mesh = kwargs['device_mesh'] - expert_mesh = device_mesh['expert_parallel'] expert_placements: List[Placement] = [Shard(0)] # Register in two loops as you cannot overwrite parameters while iterating over named_parameters() @@ -340,7 +337,6 @@ def _patch_ffn_mb( for name, dtensorified_param in dtensorified_params: ffn.experts.mlp.register_parameter(name, dtensorified_param) - device_mesh = kwargs['device_mesh'] if device_mesh.mesh.ndim == 2: submesh = device_mesh['weight_parallel'] elif device_mesh.mesh.ndim == 3: @@ -389,7 +385,6 @@ def build_mb_moe( expert_parallel_group=expert_parallel_group, device_mesh=kwargs['device_mesh'], args=args, - **kwargs, ) return ffn @@ -434,7 +429,6 @@ def build_mb_dmoe( expert_parallel_group=expert_parallel_group, device_mesh=kwargs['device_mesh'], args=args, - **kwargs, ) return ffn From cd7399160ece98ea2f85bb775f6be2b0f0ae7534 Mon Sep 17 00:00:00 2001 From: Daniel King Date: Thu, 11 Apr 2024 19:03:50 -0700 Subject: [PATCH 10/32] fix --- llmfoundry/models/layers/blocks.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git 
a/llmfoundry/models/layers/blocks.py b/llmfoundry/models/layers/blocks.py index d1ed06b9d6..ee4d74c121 100644 --- a/llmfoundry/models/layers/blocks.py +++ b/llmfoundry/models/layers/blocks.py @@ -8,6 +8,7 @@ import torch import torch.nn as nn +from llmfoundry.layers_registry import ffns_with_norm from llmfoundry.models.layers.attention import ATTN_CLASS_REGISTRY from llmfoundry.models.layers.layer_builders import build_ffn, build_norm @@ -239,17 +240,9 @@ def __init__( ) ffn_type = ffn_config.pop('ffn_type') - self.ffn = build_ffn( - name=ffn_type, - d_model=d_model, - expansion_ratio=expansion_ratio, - device=device, - bias=not no_bias, - ffn_kwargs=ffn_config, - ) self.norm_2 = None - if not getattr(self.ffn, '_has_norm', False): + if not ffn_type in ffns_with_norm: self.norm_2 = build_norm( name=norm_type.lower(), normalized_shape=d_model, From 889ea6c6c6094f08adc5d17a5e791dc6e1894c5b Mon Sep 17 00:00:00 2001 From: Daniel King Date: Thu, 11 Apr 2024 19:04:59 -0700 Subject: [PATCH 11/32] fix --- llmfoundry/models/layers/blocks.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/llmfoundry/models/layers/blocks.py b/llmfoundry/models/layers/blocks.py index ee4d74c121..6b44254a05 100644 --- a/llmfoundry/models/layers/blocks.py +++ b/llmfoundry/models/layers/blocks.py @@ -86,7 +86,6 @@ def __init__( self.norm_attn_norm = FusedNormAttentionNorm( d_model=d_model, n_heads=n_heads, - expansion_ratio=expansion_ratio, attn_config=attn_config, ffn_config=ffn_config, fc_type=fc_type, @@ -198,7 +197,6 @@ def __init__( self, d_model: int, n_heads: int, - expansion_ratio: float, attn_config: Optional[Dict] = None, ffn_config: Optional[Dict] = None, fc_type: str = 'torch', From b61d761704f40b05ff2afd3012326fc50e1d9804 Mon Sep 17 00:00:00 2001 From: Daniel King Date: Thu, 11 Apr 2024 19:32:54 -0700 Subject: [PATCH 12/32] fix? 
--- llmfoundry/models/layers/blocks.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/llmfoundry/models/layers/blocks.py b/llmfoundry/models/layers/blocks.py index 31d9bf6b5f..33b60c42dc 100644 --- a/llmfoundry/models/layers/blocks.py +++ b/llmfoundry/models/layers/blocks.py @@ -73,8 +73,10 @@ def __init__( del kwargs # unused, just to capture any extra args from the config super().__init__() + ffn_type = ffn_config.pop('ffn_type') + self.ffn = build_ffn( - name=ffn_config['ffn_type'], + name=ffn_type, d_model=d_model, expansion_ratio=expansion_ratio, device=device, From 21ea5f06853395537eb1b7de5a8623f4397c2f26 Mon Sep 17 00:00:00 2001 From: Daniel King Date: Thu, 11 Apr 2024 19:35:42 -0700 Subject: [PATCH 13/32] debug --- llmfoundry/models/mpt/modeling_mpt.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index aca2350051..c812d9e6b7 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -325,6 +325,7 @@ def __init__(self, config: MPTConfig): self.emb_drop = nn.Dropout(config.emb_pdrop) self.mb_args = None block_args = config.to_dict() + print(block_args['ffn_config']) if block_args['ffn_config']['ffn_type'] in ffns_with_megablocks: block_args['ffn_config'] = config_moe_args( block_args['ffn_config'], @@ -332,6 +333,7 @@ def __init__(self, config: MPTConfig): config.expansion_ratio, config.n_layers, ) + print(block_args['ffn_config']) self.mb_args = block_args['ffn_config'].get('args') self.blocks = nn.ModuleList([ MPTBlock( From 22f4c1470f6ab2da5eb4f900f56f0640a0b8c7cb Mon Sep 17 00:00:00 2001 From: Daniel King Date: Thu, 11 Apr 2024 19:39:12 -0700 Subject: [PATCH 14/32] fix --- llmfoundry/models/layers/blocks.py | 5 +---- llmfoundry/models/mpt/modeling_mpt.py | 2 -- 2 files changed, 1 insertion(+), 6 deletions(-) diff --git a/llmfoundry/models/layers/blocks.py b/llmfoundry/models/layers/blocks.py index 33b60c42dc..2b94d35d64 100644 --- a/llmfoundry/models/layers/blocks.py +++ b/llmfoundry/models/layers/blocks.py @@ -201,7 +201,7 @@ def __init__( d_model: int, n_heads: int, attn_config: Optional[Dict] = None, - ffn_config: Optional[Dict] = None, + ffn_type: str = 'mptmlp', fc_type: str = 'torch', resid_pdrop: float = 0.0, norm_type: str = 'low_precision_layernorm', @@ -211,7 +211,6 @@ def __init__( ): super().__init__() assert attn_config is not None - assert ffn_config is not None assert isinstance(attn_config['attn_type'], str) # necessary to avoid passing extraneous args into attn_class while allowing the use of **kwargs @@ -242,8 +241,6 @@ def __init__( }, ) - ffn_type = ffn_config.pop('ffn_type') - self.norm_2 = None if not ffn_type in ffns_with_norm: self.norm_2 = build_norm( diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index c812d9e6b7..aca2350051 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -325,7 +325,6 @@ def __init__(self, config: MPTConfig): self.emb_drop = nn.Dropout(config.emb_pdrop) self.mb_args = None block_args = config.to_dict() - print(block_args['ffn_config']) if block_args['ffn_config']['ffn_type'] in ffns_with_megablocks: block_args['ffn_config'] = config_moe_args( block_args['ffn_config'], @@ -333,7 +332,6 @@ def __init__(self, config: MPTConfig): config.expansion_ratio, config.n_layers, ) - print(block_args['ffn_config']) self.mb_args = block_args['ffn_config'].get('args') self.blocks = nn.ModuleList([ MPTBlock( From 
c0ac73aba04c02e77d5edb404b2a46727ab27203 Mon Sep 17 00:00:00 2001 From: Daniel King Date: Thu, 11 Apr 2024 19:45:56 -0700 Subject: [PATCH 15/32] debug --- llmfoundry/models/layers/blocks.py | 3 ++- llmfoundry/models/mpt/modeling_mpt.py | 3 +++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/llmfoundry/models/layers/blocks.py b/llmfoundry/models/layers/blocks.py index 2b94d35d64..5f6fba1803 100644 --- a/llmfoundry/models/layers/blocks.py +++ b/llmfoundry/models/layers/blocks.py @@ -61,6 +61,7 @@ def __init__( use_pad_tok_in_ffn: bool = True, **kwargs: Any, ): + print(ffn_config) if attn_config is None: attn_config = attn_config_defaults @@ -89,7 +90,7 @@ def __init__( d_model=d_model, n_heads=n_heads, attn_config=attn_config, - ffn_config=ffn_config, + ffn_type=ffn_type, fc_type=fc_type, resid_pdrop=resid_pdrop, norm_type=norm_type, diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index aca2350051..3baab79469 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -325,6 +325,7 @@ def __init__(self, config: MPTConfig): self.emb_drop = nn.Dropout(config.emb_pdrop) self.mb_args = None block_args = config.to_dict() + print(block_args['ffn_config'], block_args['ffn_config']['ffn_type']) if block_args['ffn_config']['ffn_type'] in ffns_with_megablocks: block_args['ffn_config'] = config_moe_args( block_args['ffn_config'], @@ -333,6 +334,8 @@ def __init__(self, config: MPTConfig): config.n_layers, ) self.mb_args = block_args['ffn_config'].get('args') + print(block_args['ffn_config'], block_args['ffn_config']['ffn_type']) + self.blocks = nn.ModuleList([ MPTBlock( device=config.init_device, From 7c3d4938a1e31a0a40459a3c71debb13d1437c2a Mon Sep 17 00:00:00 2001 From: Daniel King Date: Thu, 11 Apr 2024 19:58:08 -0700 Subject: [PATCH 16/32] debug --- llmfoundry/models/layers/blocks.py | 2 +- llmfoundry/models/layers/layer_builders.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/llmfoundry/models/layers/blocks.py b/llmfoundry/models/layers/blocks.py index 5f6fba1803..77f39c2a38 100644 --- a/llmfoundry/models/layers/blocks.py +++ b/llmfoundry/models/layers/blocks.py @@ -74,7 +74,7 @@ def __init__( del kwargs # unused, just to capture any extra args from the config super().__init__() - ffn_type = ffn_config.pop('ffn_type') + ffn_type = ffn_config['ffn_type'] self.ffn = build_ffn( name=ffn_type, diff --git a/llmfoundry/models/layers/layer_builders.py b/llmfoundry/models/layers/layer_builders.py index ed24de358e..1d32b6baf7 100644 --- a/llmfoundry/models/layers/layer_builders.py +++ b/llmfoundry/models/layers/layer_builders.py @@ -35,6 +35,7 @@ def build_ffn( bias: bool, ffn_kwargs: Dict[str, Any], ): + registry_to_use = ffns if name in ffns_with_norm: registry_to_use = ffns_with_norm @@ -47,7 +48,7 @@ def build_ffn( 'expansion_ratio': expansion_ratio, 'device': device, 'bias': bias, - **ffn_kwargs, + **{k:v for k,v in ffn_kwargs.items() if k != 'ffn_type'}, } def _validation_function(maybe_module: Any): From 59380489c9953a8689119483801ed5d405f7dff6 Mon Sep 17 00:00:00 2001 From: Daniel King Date: Thu, 11 Apr 2024 20:00:47 -0700 Subject: [PATCH 17/32] pc and remove prints --- llmfoundry/models/layers/blocks.py | 1 - llmfoundry/models/layers/layer_builders.py | 2 +- llmfoundry/models/mpt/modeling_mpt.py | 2 -- 3 files changed, 1 insertion(+), 4 deletions(-) diff --git a/llmfoundry/models/layers/blocks.py b/llmfoundry/models/layers/blocks.py index 77f39c2a38..c57e093860 100644 --- 
a/llmfoundry/models/layers/blocks.py +++ b/llmfoundry/models/layers/blocks.py @@ -61,7 +61,6 @@ def __init__( use_pad_tok_in_ffn: bool = True, **kwargs: Any, ): - print(ffn_config) if attn_config is None: attn_config = attn_config_defaults diff --git a/llmfoundry/models/layers/layer_builders.py b/llmfoundry/models/layers/layer_builders.py index 1d32b6baf7..425fcaf862 100644 --- a/llmfoundry/models/layers/layer_builders.py +++ b/llmfoundry/models/layers/layer_builders.py @@ -48,7 +48,7 @@ def build_ffn( 'expansion_ratio': expansion_ratio, 'device': device, 'bias': bias, - **{k:v for k,v in ffn_kwargs.items() if k != 'ffn_type'}, + **{k: v for k, v in ffn_kwargs.items() if k != 'ffn_type'}, } def _validation_function(maybe_module: Any): diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 3baab79469..c19ab753f0 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -325,7 +325,6 @@ def __init__(self, config: MPTConfig): self.emb_drop = nn.Dropout(config.emb_pdrop) self.mb_args = None block_args = config.to_dict() - print(block_args['ffn_config'], block_args['ffn_config']['ffn_type']) if block_args['ffn_config']['ffn_type'] in ffns_with_megablocks: block_args['ffn_config'] = config_moe_args( block_args['ffn_config'], @@ -334,7 +333,6 @@ def __init__(self, config: MPTConfig): config.n_layers, ) self.mb_args = block_args['ffn_config'].get('args') - print(block_args['ffn_config'], block_args['ffn_config']['ffn_type']) self.blocks = nn.ModuleList([ MPTBlock( From 92a9ded2f7f99a6acfd086d7e185ad03761a0098 Mon Sep 17 00:00:00 2001 From: Daniel King Date: Thu, 11 Apr 2024 20:12:02 -0700 Subject: [PATCH 18/32] fix --- llmfoundry/models/layers/blocks.py | 27 ++++++++++++++------------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/llmfoundry/models/layers/blocks.py b/llmfoundry/models/layers/blocks.py index c57e093860..0465226820 100644 --- a/llmfoundry/models/layers/blocks.py +++ b/llmfoundry/models/layers/blocks.py @@ -74,22 +74,14 @@ def __init__( super().__init__() ffn_type = ffn_config['ffn_type'] - - self.ffn = build_ffn( - name=ffn_type, - d_model=d_model, - expansion_ratio=expansion_ratio, - device=device, - bias=not no_bias, - ffn_kwargs=ffn_config, - ) + ffn_has_norm = not ffn_type in ffns_with_norm if self.fuse_norm_attn_norm: self.norm_attn_norm = FusedNormAttentionNorm( d_model=d_model, n_heads=n_heads, attn_config=attn_config, - ffn_type=ffn_type, + ffn_has_norm=ffn_has_norm, fc_type=fc_type, resid_pdrop=resid_pdrop, norm_type=norm_type, @@ -127,13 +119,22 @@ def __init__( }, ) self.norm_2 = None - if not getattr(self.ffn, '_has_norm', False): + if not ffn_has_norm: self.norm_2 = build_norm( name=norm_type.lower(), normalized_shape=d_model, device=device, ) + self.ffn = build_ffn( + name=ffn_type, + d_model=d_model, + expansion_ratio=expansion_ratio, + device=device, + bias=not no_bias, + ffn_kwargs=ffn_config, + ) + self.resid_attn_dropout = nn.Dropout(resid_pdrop) self.resid_ffn_dropout = nn.Dropout(resid_pdrop) self.use_pad_tok_in_ffn = use_pad_tok_in_ffn @@ -201,7 +202,7 @@ def __init__( d_model: int, n_heads: int, attn_config: Optional[Dict] = None, - ffn_type: str = 'mptmlp', + ffn_has_norm: bool = False, fc_type: str = 'torch', resid_pdrop: float = 0.0, norm_type: str = 'low_precision_layernorm', @@ -242,7 +243,7 @@ def __init__( ) self.norm_2 = None - if not ffn_type in ffns_with_norm: + if not ffn_has_norm: self.norm_2 = build_norm( name=norm_type.lower(), 
normalized_shape=d_model, From dec23afaed20d0362ac7ff721b4601921ea85e3e Mon Sep 17 00:00:00 2001 From: Daniel King Date: Thu, 11 Apr 2024 20:20:00 -0700 Subject: [PATCH 19/32] fix --- llmfoundry/models/layers/blocks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llmfoundry/models/layers/blocks.py b/llmfoundry/models/layers/blocks.py index 0465226820..40f349368f 100644 --- a/llmfoundry/models/layers/blocks.py +++ b/llmfoundry/models/layers/blocks.py @@ -74,7 +74,7 @@ def __init__( super().__init__() ffn_type = ffn_config['ffn_type'] - ffn_has_norm = not ffn_type in ffns_with_norm + ffn_has_norm = ffn_type in ffns_with_norm if self.fuse_norm_attn_norm: self.norm_attn_norm = FusedNormAttentionNorm( From b0d2849c034b326f9d697b144fece8b9de12db36 Mon Sep 17 00:00:00 2001 From: Daniel King Date: Fri, 12 Apr 2024 04:25:28 +0000 Subject: [PATCH 20/32] fix --- llmfoundry/models/mpt/modeling_mpt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index c19ab753f0..124ab3db3e 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -48,7 +48,6 @@ build_attn_bias, gen_slopes) from llmfoundry.models.layers.blocks import MPTBlock from llmfoundry.models.layers.custom_embedding import SharedEmbedding -from llmfoundry.models.layers.layer_builders import build_ffn as build_ffn from llmfoundry.models.layers.layer_builders import build_norm from llmfoundry.models.mpt.configuration_mpt import MPTConfig from llmfoundry.models.utils.config_moe_args import config_moe_args @@ -65,6 +64,7 @@ generic_param_init_fn_, # type: ignore (see note) MODEL_INIT_REGISTRY, ) +from llmfoundry.models.layers.ffn import resolve_ffn_act_fn # type: ignore (see note) from llmfoundry.models.utils.act_ckpt import (pass_on_block_idx, build_act_ckpt_mod_to_blocks, From 0a5026e3887a23c7a2cdb58d6f03632cd805ba0c Mon Sep 17 00:00:00 2001 From: Daniel King Date: Fri, 12 Apr 2024 04:34:42 +0000 Subject: [PATCH 21/32] fix tests --- tests/models/layers/test_dmoe.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/models/layers/test_dmoe.py b/tests/models/layers/test_dmoe.py index 9c15745793..c8e7ec3e67 100644 --- a/tests/models/layers/test_dmoe.py +++ b/tests/models/layers/test_dmoe.py @@ -239,6 +239,10 @@ def test_fwd_equal_dmoe(seqlen: int, precision: str, mlp_type: str): torch_dmoe_config = copy.deepcopy(mb_dmoe_config) torch_dmoe_config.ffn_config['ffn_type'] = 'torch_dmoe' + del torch_dmoe_config.ffn_config['moe_world_size'] + del torch_dmoe_config.ffn_config['fc_type'] + del torch_dmoe_config.ffn_config['moe_loss_weight'] + del torch_dmoe_config.ffn_config['return_bias'] mb_dmoe_model = MPTForCausalLM(mb_dmoe_config).to(device=device, dtype=dtype) From e7ac0a93eec6018741ed28248cfc2451dbe027ac Mon Sep 17 00:00:00 2001 From: Daniel King Date: Thu, 11 Apr 2024 21:54:25 -0700 Subject: [PATCH 22/32] unused kwarg --- llmfoundry/models/layers/ffn.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/llmfoundry/models/layers/ffn.py b/llmfoundry/models/layers/ffn.py index a8fbaef30a..79749ad663 100644 --- a/llmfoundry/models/layers/ffn.py +++ b/llmfoundry/models/layers/ffn.py @@ -256,6 +256,9 @@ def build_torch_dmoe( moe_normalize_expert_weights = kwargs.pop('moe_normalize_expert_weights') uniform_expert_assignment = kwargs.pop('uniform_expert_assignment') + fc_type = kwargs.pop('fc_type', 'torch') + del fc_type # Unused + if len(kwargs) > 0: raise ValueError(f'Invalid arguments to 
torch dmoe: {kwargs}.') From b9a5abc36ec74c783c87dd6373d0b951f1f91933 Mon Sep 17 00:00:00 2001 From: Daniel King Date: Fri, 12 Apr 2024 05:02:14 +0000 Subject: [PATCH 23/32] logs --- scripts/train/train.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/scripts/train/train.py b/scripts/train/train.py index 5bcce0038e..96226e8c7c 100644 --- a/scripts/train/train.py +++ b/scripts/train/train.py @@ -102,12 +102,18 @@ def validate_config(cfg: DictConfig): '`load_in_8bit` is only supported for evaluation rather than training.' ) + print('in validate') + print(cfg.model.get('ffn_config', {}).get('ffn_type', 'mptmlp')) + print(cfg.model.get('ffn_config', {}).get('ffn_type', 'mptmlp') in ffns_with_megablocks) if cfg.model.get('ffn_config', {}).get('ffn_type', 'mptmlp') in ffns_with_megablocks: + print('inside') moe_world_size = cfg.model.get('ffn_config', {}).get('moe_world_size', 1) + print(moe_world_size) use_orig_params = cfg.get('fsdp_config', {}).get('use_orig_params', True) + print(use_orig_params) if moe_world_size > 1 and not use_orig_params: raise ValueError( f'MoEs with expert parallelism (moe_world_size {moe_world_size} > 1) require `use_orig_params=True`.' From ea56fa99421af06a8816b1b78999053fc55b68d9 Mon Sep 17 00:00:00 2001 From: Daniel King Date: Thu, 11 Apr 2024 22:12:04 -0700 Subject: [PATCH 24/32] fix cpu test --- tests/a_scripts/train/test_train.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/a_scripts/train/test_train.py b/tests/a_scripts/train/test_train.py index ff885ac735..fe58a44459 100644 --- a/tests/a_scripts/train/test_train.py +++ b/tests/a_scripts/train/test_train.py @@ -147,6 +147,7 @@ def test_train_multi_eval(tmp_path: pathlib.Path): tuple) +@pytest.mark.gpu def test_validate_config(): conf_path: str = os.path.join( REPO_DIR, From bce0ebb3d43d782622d46d25099d4f11cc7ed5e2 Mon Sep 17 00:00:00 2001 From: Daniel King Date: Thu, 11 Apr 2024 22:19:49 -0700 Subject: [PATCH 25/32] pc --- scripts/train/train.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/scripts/train/train.py b/scripts/train/train.py index 96226e8c7c..b923fbc7cd 100644 --- a/scripts/train/train.py +++ b/scripts/train/train.py @@ -104,7 +104,9 @@ def validate_config(cfg: DictConfig): print('in validate') print(cfg.model.get('ffn_config', {}).get('ffn_type', 'mptmlp')) - print(cfg.model.get('ffn_config', {}).get('ffn_type', 'mptmlp') in ffns_with_megablocks) + print( + cfg.model.get('ffn_config', {}).get('ffn_type', 'mptmlp') in + ffns_with_megablocks) if cfg.model.get('ffn_config', {}).get('ffn_type', 'mptmlp') in ffns_with_megablocks: print('inside') From f72db7859e0a43e8d6df92b1073b202db01d2822 Mon Sep 17 00:00:00 2001 From: Daniel King Date: Thu, 11 Apr 2024 22:21:25 -0700 Subject: [PATCH 26/32] remove prints --- scripts/train/train.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/scripts/train/train.py b/scripts/train/train.py index b923fbc7cd..5bcce0038e 100644 --- a/scripts/train/train.py +++ b/scripts/train/train.py @@ -102,20 +102,12 @@ def validate_config(cfg: DictConfig): '`load_in_8bit` is only supported for evaluation rather than training.' 
) - print('in validate') - print(cfg.model.get('ffn_config', {}).get('ffn_type', 'mptmlp')) - print( - cfg.model.get('ffn_config', {}).get('ffn_type', 'mptmlp') in - ffns_with_megablocks) if cfg.model.get('ffn_config', {}).get('ffn_type', 'mptmlp') in ffns_with_megablocks: - print('inside') moe_world_size = cfg.model.get('ffn_config', {}).get('moe_world_size', 1) - print(moe_world_size) use_orig_params = cfg.get('fsdp_config', {}).get('use_orig_params', True) - print(use_orig_params) if moe_world_size > 1 and not use_orig_params: raise ValueError( f'MoEs with expert parallelism (moe_world_size {moe_world_size} > 1) require `use_orig_params=True`.' From 9e2f72823b44d8727bd54f6ce9a3556ab31a156a Mon Sep 17 00:00:00 2001 From: Daniel King Date: Thu, 11 Apr 2024 22:40:39 -0700 Subject: [PATCH 27/32] logs --- llmfoundry/models/utils/param_init_fns.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/llmfoundry/models/utils/param_init_fns.py b/llmfoundry/models/utils/param_init_fns.py index bd409dee36..effad60544 100644 --- a/llmfoundry/models/utils/param_init_fns.py +++ b/llmfoundry/models/utils/param_init_fns.py @@ -157,6 +157,13 @@ def generic_param_init_fn_( emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None, **kwargs: Any, ) -> None: + + print('in init') + print(type(module)) + print(isinstance(module, GLU)) + print(isinstance(module, MLP)) + print(GLU) + print(MLP) del kwargs # unused, just to capture any extra args from the config # enable user to divide _is_residual weights by From d7192af2b2ea9328545000c7d18cbfc2aaabcf0a Mon Sep 17 00:00:00 2001 From: Daniel King Date: Thu, 11 Apr 2024 22:56:59 -0700 Subject: [PATCH 28/32] pc --- llmfoundry/models/utils/param_init_fns.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llmfoundry/models/utils/param_init_fns.py b/llmfoundry/models/utils/param_init_fns.py index effad60544..495bb14a69 100644 --- a/llmfoundry/models/utils/param_init_fns.py +++ b/llmfoundry/models/utils/param_init_fns.py @@ -157,7 +157,7 @@ def generic_param_init_fn_( emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None, **kwargs: Any, ) -> None: - + print('in init') print(type(module)) print(isinstance(module, GLU)) From 53e2614e686c32469089656bd4c2779f58d8420e Mon Sep 17 00:00:00 2001 From: Daniel King Date: Thu, 11 Apr 2024 23:48:39 -0700 Subject: [PATCH 29/32] maybe --- tests/fixtures/autouse.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/tests/fixtures/autouse.py b/tests/fixtures/autouse.py index ccbe1b69f7..cb6bd10ce2 100644 --- a/tests/fixtures/autouse.py +++ b/tests/fixtures/autouse.py @@ -1,6 +1,7 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 +import copy import gc import os import sys @@ -13,6 +14,15 @@ REPO_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')) sys.path.append(REPO_DIR) +@pytest.fixture(autouse=True) +def save_registry(): + from catalogue import REGISTRY + # Save it + saved_registry = copy.deepcopy(REGISTRY) + # Yield + yield + # Restore it + REGISTRY.update(saved_registry) @pytest.fixture(autouse=True) def initialize_dist(request: pytest.FixtureRequest): From 3fd8086462bdad90eb84d9c720b5b595f477b208 Mon Sep 17 00:00:00 2001 From: Daniel King Date: Thu, 11 Apr 2024 23:49:08 -0700 Subject: [PATCH 30/32] pc --- tests/fixtures/autouse.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/fixtures/autouse.py b/tests/fixtures/autouse.py index cb6bd10ce2..04c0812aeb 100644 --- a/tests/fixtures/autouse.py 
+++ b/tests/fixtures/autouse.py @@ -14,9 +14,11 @@ REPO_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')) sys.path.append(REPO_DIR) + @pytest.fixture(autouse=True) def save_registry(): from catalogue import REGISTRY + # Save it saved_registry = copy.deepcopy(REGISTRY) # Yield @@ -24,6 +26,7 @@ def save_registry(): # Restore it REGISTRY.update(saved_registry) + @pytest.fixture(autouse=True) def initialize_dist(request: pytest.FixtureRequest): """Initialize the default PyTorch distributed process group for tests.""" From f8d4c8f7ee82eec5c51c72f1534b9a9317cc4cf7 Mon Sep 17 00:00:00 2001 From: Daniel King Date: Fri, 12 Apr 2024 00:09:13 -0700 Subject: [PATCH 31/32] maybe fix --- llmfoundry/models/utils/param_init_fns.py | 7 ------- llmfoundry/utils/registry_utils.py | 12 ++++++++++++ tests/fixtures/autouse.py | 15 +++++---------- 3 files changed, 17 insertions(+), 17 deletions(-) diff --git a/llmfoundry/models/utils/param_init_fns.py b/llmfoundry/models/utils/param_init_fns.py index 495bb14a69..bd409dee36 100644 --- a/llmfoundry/models/utils/param_init_fns.py +++ b/llmfoundry/models/utils/param_init_fns.py @@ -157,13 +157,6 @@ def generic_param_init_fn_( emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None, **kwargs: Any, ) -> None: - - print('in init') - print(type(module)) - print(isinstance(module, GLU)) - print(isinstance(module, MLP)) - print(GLU) - print(MLP) del kwargs # unused, just to capture any extra args from the config # enable user to divide _is_residual weights by diff --git a/llmfoundry/utils/registry_utils.py b/llmfoundry/utils/registry_utils.py index 83348a7fd6..0eeefbae74 100644 --- a/llmfoundry/utils/registry_utils.py +++ b/llmfoundry/utils/registry_utils.py @@ -1,9 +1,11 @@ # Copyright 2024 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 +import copy import functools import importlib.util import os +from contextlib import contextmanager from pathlib import Path from types import ModuleType from typing import (Any, Callable, Dict, Generic, Optional, Sequence, Type, @@ -174,3 +176,13 @@ def import_file(loc: Union[str, Path]) -> ModuleType: except Exception as e: raise RuntimeError(f'Error executing {loc}') from e return module + + +@contextmanager +def save_registry(): + """Save the registry state and restore after the context manager exits.""" + saved_registry_state = copy.deepcopy(catalogue.REGISTRY) + + yield + + catalogue.REGISTRY = saved_registry_state diff --git a/tests/fixtures/autouse.py b/tests/fixtures/autouse.py index 04c0812aeb..16e3f8ad6f 100644 --- a/tests/fixtures/autouse.py +++ b/tests/fixtures/autouse.py @@ -1,7 +1,6 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 -import copy import gc import os import sys @@ -10,21 +9,17 @@ import torch from composer.utils import dist, get_device, reproducibility +from llmfoundry.utils.registry_utils import save_registry + # Add llm-foundry repo root to path so we can import scripts in the tests REPO_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')) sys.path.append(REPO_DIR) @pytest.fixture(autouse=True) -def save_registry(): - from catalogue import REGISTRY - - # Save it - saved_registry = copy.deepcopy(REGISTRY) - # Yield - yield - # Restore it - REGISTRY.update(saved_registry) +def save_registry_fixture(): + with save_registry(): + yield @pytest.fixture(autouse=True) From d1fda2d3bbe09a16506b774787979d09d4be34d6 Mon Sep 17 00:00:00 2001 From: Daniel King Date: Fri, 12 Apr 2024 13:20:21 -0700 Subject: 
[PATCH 32/32] fix pyright --- llmfoundry/models/layers/dmoe.py | 51 ++++++++++++++++++++++---------- llmfoundry/models/layers/ffn.py | 2 +- 2 files changed, 37 insertions(+), 16 deletions(-) diff --git a/llmfoundry/models/layers/dmoe.py b/llmfoundry/models/layers/dmoe.py index 1a981b61c5..19cd67b8aa 100644 --- a/llmfoundry/models/layers/dmoe.py +++ b/llmfoundry/models/layers/dmoe.py @@ -1,7 +1,7 @@ # Copyright 2024 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 -from typing import Callable +from typing import Callable, Optional import torch @@ -24,7 +24,8 @@ class LearnedRouter(torch.nn.Module): def __init__(self, hidden_size: int, moe_num_experts: int, moe_top_k: int, moe_jitter_eps: float, moe_normalize_expert_weights: bool, - uniform_expert_assignment: bool, device: torch.device) -> None: + uniform_expert_assignment: bool, + device: Optional[torch.device]) -> None: super().__init__() self.hidden_size: int = hidden_size self.moe_num_experts: int = moe_num_experts @@ -84,7 +85,7 @@ def __init__( ffn_hidden_size: int, moe_num_experts: int, activation_fn: Callable, - device: torch.device, + device: Optional[torch.device], ) -> None: super().__init__() @@ -117,9 +118,14 @@ def forward(self, x: torch.Tensor, expert_idx: int) -> torch.Tensor: class GLU(torch.nn.Module): - def __init__(self, hidden_size: int, ffn_hidden_size: int, - moe_num_experts: int, activation_fn: Callable, - device: torch.device): + def __init__( + self, + hidden_size: int, + ffn_hidden_size: int, + moe_num_experts: int, + activation_fn: Callable, + device: Optional[torch.device], + ): super().__init__() self.hidden_size = hidden_size self.ffn_hidden_size = ffn_hidden_size @@ -157,9 +163,16 @@ def forward(self, x: torch.Tensor, expert_idx: torch.Tensor): class DroplessMLP(torch.nn.Module): - def __init__(self, hidden_size: int, ffn_hidden_size: int, mlp_type: str, - moe_num_experts: int, activation_fn: Callable, bias: bool, - device: torch.device): + def __init__( + self, + hidden_size: int, + ffn_hidden_size: int, + mlp_type: str, + moe_num_experts: int, + activation_fn: Callable, + bias: bool, + device: Optional[torch.device], + ): super().__init__() self.moe_num_experts = moe_num_experts @@ -209,12 +222,20 @@ def forward(self, x: torch.Tensor, scores: torch.Tensor, class dMoE(torch.nn.Module): - def __init__(self, hidden_size: int, ffn_hidden_size: int, - moe_num_experts: int, moe_top_k: int, mlp_type: str, - activation_fn: Callable, moe_jitter_eps: float, - moe_normalize_expert_weights: bool, - uniform_expert_assignment: bool, bias: bool, - device: torch.device): + def __init__( + self, + hidden_size: int, + ffn_hidden_size: int, + moe_num_experts: int, + moe_top_k: int, + mlp_type: str, + activation_fn: Callable, + moe_jitter_eps: float, + moe_normalize_expert_weights: bool, + uniform_expert_assignment: bool, + bias: bool, + device: Optional[torch.device], + ): super().__init__() # Token router. diff --git a/llmfoundry/models/layers/ffn.py b/llmfoundry/models/layers/ffn.py index 79749ad663..fb663b4c3c 100644 --- a/llmfoundry/models/layers/ffn.py +++ b/llmfoundry/models/layers/ffn.py @@ -274,7 +274,7 @@ def build_torch_dmoe( activation_fn=resolve_ffn_act_fn(ffn_act_fn), moe_normalize_expert_weights=moe_normalize_expert_weights, uniform_expert_assignment=uniform_expert_assignment, - device=device, # pyright: ignore[reportGeneralTypeIssues] + device=torch.device(device) if device is not None else None, )
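
A minimal usage sketch (editorial addition, not part of the patch series) for the save_registry context manager introduced in PATCH 31/32 above. The registry key ('llmfoundry', 'ffns', 'temp_ffn') and the dummy builder are hypothetical stand-ins for a real registration, and catalogue.REGISTRY is poked directly only to keep the demo short.

# Sketch: temporary registrations made inside save_registry() do not leak out.
import copy

import catalogue

from llmfoundry.utils.registry_utils import save_registry

snapshot = copy.deepcopy(catalogue.REGISTRY)

with save_registry():
    # catalogue.REGISTRY is a plain dict keyed by namespace tuples, so a direct
    # assignment is enough to simulate a registration for this demo.
    catalogue.REGISTRY[('llmfoundry', 'ffns', 'temp_ffn')] = lambda **kwargs: None
    assert ('llmfoundry', 'ffns', 'temp_ffn') in catalogue.REGISTRY

# On exit, save_registry() swaps the deep copy taken on entry back into
# catalogue.REGISTRY, so the temporary entry is gone.
assert ('llmfoundry', 'ffns', 'temp_ffn') not in catalogue.REGISTRY
assert catalogue.REGISTRY.keys() == snapshot.keys()

Because the restore reassigns catalogue.REGISTRY rather than mutating the dict in place, the sketch imports the catalogue module itself; a `from catalogue import REGISTRY` alias would keep pointing at the stale, mutated dict after the block exits.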