diff --git a/llmfoundry/layers_registry.py b/llmfoundry/layers_registry.py index 9c7dabe128..9686e76af7 100644 --- a/llmfoundry/layers_registry.py +++ b/llmfoundry/layers_registry.py @@ -14,7 +14,18 @@ generic_type=Type[torch.nn.Module], entry_points=True, description=_norm_description) +_fc_description = ( + 'The fully connected layers registry is used to register classes that implement fully connected layers (i.e. torch.nn.Linear).' + + + 'These classes should take in_features and out_features in as args, at a minimum.' +) +fcs = create_registry('llmfoundry', + 'fcs', + generic_type=Type[torch.nn.Module], + entry_points=True, + description=_fc_description) __all__ = [ 'norms', + 'fcs', ] diff --git a/llmfoundry/models/layers/__init__.py b/llmfoundry/models/layers/__init__.py index 262f190b47..7328a00757 100644 --- a/llmfoundry/models/layers/__init__.py +++ b/llmfoundry/models/layers/__init__.py @@ -7,7 +7,7 @@ flash_attn_fn, scaled_multihead_dot_product_attention) from llmfoundry.models.layers.blocks import MPTBlock from llmfoundry.models.layers.custom_embedding import SharedEmbedding -from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY +from llmfoundry.models.layers.fc import * from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY, MPTMLP, build_ffn from llmfoundry.models.layers.norm import LPLayerNorm @@ -24,7 +24,6 @@ 'MPTMLP', 'MPTBlock', 'LPLayerNorm', - 'FC_CLASS_REGISTRY', 'SharedEmbedding', 'FFN_CLASS_REGISTRY', 'build_ffn', diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py index c24b3d4afa..331c037ce0 100644 --- a/llmfoundry/models/layers/attention.py +++ b/llmfoundry/models/layers/attention.py @@ -14,8 +14,7 @@ from packaging import version from torch import nn -from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY -from llmfoundry.models.layers.layer_builders import build_norm +from llmfoundry.models.layers.layer_builders import build_fc, build_norm def is_flash_v2_installed(v2_version: str = '2.0.0'): @@ -406,10 +405,11 @@ def __init__( 'bias': bias, } fc_kwargs['device'] = device - self.Wqkv = FC_CLASS_REGISTRY[fc_type]( - self.d_model, - self.d_model + 2 * self.kv_n_heads * self.head_dim, - **fc_kwargs, + self.Wqkv = build_fc( + name=fc_type, + in_features=self.d_model, + out_features=self.d_model + 2 * self.kv_n_heads * self.head_dim, + fc_kwargs=fc_kwargs, ) # for param init fn; enables shape based init of fused layers fuse_splits = [ @@ -440,10 +440,11 @@ def __init__( else: raise ValueError(f'{attn_impl=} is an invalid setting.') - self.out_proj = FC_CLASS_REGISTRY[fc_type]( - self.d_model, - self.d_model, - **fc_kwargs, + self.out_proj = build_fc( + name=fc_type, + in_features=self.d_model, + out_features=self.d_model, + fc_kwargs=fc_kwargs, ) self.out_proj._is_residual = True diff --git a/llmfoundry/models/layers/fc.py b/llmfoundry/models/layers/fc.py index b85bc133bd..8650e4966f 100644 --- a/llmfoundry/models/layers/fc.py +++ b/llmfoundry/models/layers/fc.py @@ -3,12 +3,12 @@ from torch import nn -FC_CLASS_REGISTRY = { - 'torch': nn.Linear, -} +from llmfoundry.layers_registry import fcs + +fcs.register('torch', func=nn.Linear) try: import transformer_engine.pytorch as te - FC_CLASS_REGISTRY['te'] = te.Linear + fcs.register('te', func=te.Linear) except: pass diff --git a/llmfoundry/models/layers/ffn.py b/llmfoundry/models/layers/ffn.py index 48d3d8c267..f0b499875a 100644 --- a/llmfoundry/models/layers/ffn.py +++ b/llmfoundry/models/layers/ffn.py @@ -13,7 +13,7 @@ from torch.distributed._tensor import DeviceMesh, DTensor, Placement, Shard from llmfoundry.models.layers.dmoe import dMoE -from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY +from llmfoundry.models.layers.layer_builders import build_fc try: import transformer_engine.pytorch as te @@ -52,7 +52,7 @@ def resolve_ffn_act_fn( config = deepcopy(config) name = config.pop('name') if not hasattr(torch.nn.functional, name): - raise ValueError(f'Unrecognised activation function name ({name}).') + raise ValueError(f'Unrecognized activation function name ({name}).') act = getattr(torch.nn.functional, name) return partial(act, **config) @@ -121,16 +121,18 @@ def __init__( self.fc_kwargs['device'] = device - self.up_proj = FC_CLASS_REGISTRY[fc_type]( - d_model, - ffn_hidden_size, - **self.fc_kwargs, + self.up_proj = build_fc( + name=fc_type, + in_features=d_model, + out_features=ffn_hidden_size, + fc_kwargs=self.fc_kwargs, ) self.act = act_fn - self.down_proj = FC_CLASS_REGISTRY[fc_type]( - ffn_hidden_size, - d_model, - **self.fc_kwargs, + self.down_proj = build_fc( + name=fc_type, + in_features=ffn_hidden_size, + out_features=d_model, + fc_kwargs=self.fc_kwargs, ) self.down_proj._is_residual = True @@ -159,10 +161,11 @@ def __init__( device=device, bias=bias, ) - self.gate_proj = FC_CLASS_REGISTRY[fc_type]( - d_model, - self.up_proj.out_features, - **self.fc_kwargs, + self.gate_proj = build_fc( + name=fc_type, + in_features=d_model, + out_features=self.up_proj.out_features, + fc_kwargs=self.fc_kwargs, ) def forward(self, x: torch.Tensor) -> torch.Tensor: diff --git a/llmfoundry/models/layers/layer_builders.py b/llmfoundry/models/layers/layer_builders.py index 23f5b89668..8244089115 100644 --- a/llmfoundry/models/layers/layer_builders.py +++ b/llmfoundry/models/layers/layer_builders.py @@ -1,11 +1,11 @@ # Copyright 2024 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 -from typing import List, Optional, Union +from typing import Any, Dict, List, Optional, Union import torch -from llmfoundry.layers_registry import norms +from llmfoundry.layers_registry import fcs, norms from llmfoundry.utils.registry_utils import construct_from_registry @@ -23,3 +23,21 @@ def build_norm( registry=norms, pre_validation_function=torch.nn.Module, kwargs=kwargs) + + +def build_fc( + name: str, + in_features: int, + out_features: int, + fc_kwargs: Dict[str, Any], +): + kwargs = { + 'in_features': in_features, + 'out_features': out_features, + **fc_kwargs, + } + + return construct_from_registry(name=name, + registry=fcs, + pre_validation_function=torch.nn.Module, + kwargs=kwargs) diff --git a/llmfoundry/models/mpt/configuration_mpt.py b/llmfoundry/models/mpt/configuration_mpt.py index 8383d33ec0..4b98fa611d 100644 --- a/llmfoundry/models/mpt/configuration_mpt.py +++ b/llmfoundry/models/mpt/configuration_mpt.py @@ -16,11 +16,10 @@ # HuggingFace can detect all the needed files to copy into its modules folder. # Otherwise, certain modules are missing. # isort: off -from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY # type: ignore (see note) from llmfoundry.models.layers.norm import LPLayerNorm # type: ignore (see note) from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY # type: ignore (see note) +from llmfoundry.models.layers.layer_builders import build_norm, build_fc # type: ignore (see note) from llmfoundry.models.layers.dmoe import dMoE # type: ignore (see note) -from llmfoundry.models.layers.layer_builders import build_norm # type: ignore (see note) from llmfoundry.layers_registry import norms # type: ignore (see note) from llmfoundry.utils.registry_utils import construct_from_registry # type: ignore (see note) diff --git a/llmfoundry/models/utils/param_init_fns.py b/llmfoundry/models/utils/param_init_fns.py index 16376de451..bd409dee36 100644 --- a/llmfoundry/models/utils/param_init_fns.py +++ b/llmfoundry/models/utils/param_init_fns.py @@ -12,9 +12,8 @@ from torch import nn from torch.distributed._tensor import DTensor -from llmfoundry.layers_registry import norms +from llmfoundry.layers_registry import fcs, norms from llmfoundry.models.layers.dmoe import GLU, MLP -from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY try: import transformer_engine.pytorch as te @@ -182,7 +181,7 @@ def generic_param_init_fn_( f'Expected init_div_is_residual to be boolean or numeric, got {init_div_is_residual}' ) - if isinstance(module, tuple(set(FC_CLASS_REGISTRY.values()))): + if isinstance(module, tuple(set([fcs.get(n) for n in fcs.get_all()]))): # Linear if hasattr(module, '_fused'): fused_init_helper_(module, init_fn_) diff --git a/llmfoundry/registry.py b/llmfoundry/registry.py index 424075da3b..7b605808be 100644 --- a/llmfoundry/registry.py +++ b/llmfoundry/registry.py @@ -12,7 +12,7 @@ from transformers import PreTrainedTokenizerBase from llmfoundry.interfaces import CallbackWithConfig -from llmfoundry.layers_registry import norms +from llmfoundry.layers_registry import fcs, norms from llmfoundry.utils.registry_utils import create_registry _loggers_description = ( @@ -121,4 +121,5 @@ 'metrics', 'dataloaders', 'norms', + 'fcs', ] diff --git a/tests/test_registry.py b/tests/test_registry.py index c93c7c9749..6fa0f16b49 100644 --- a/tests/test_registry.py +++ b/tests/test_registry.py @@ -31,6 +31,7 @@ def test_expected_registries_exist(): 'metrics', 'models', 'norms', + 'fcs', } assert existing_registries == expected_registry_names