From 6842520562abefe75ab121428ca4507b65ba39f5 Mon Sep 17 00:00:00 2001
From: Daniel King
Date: Wed, 3 Apr 2024 21:23:36 -0700
Subject: [PATCH 1/4] fc registry

---
 llmfoundry/layers_registry.py              |  8 ++++++
 llmfoundry/models/layers/__init__.py       |  2 --
 llmfoundry/models/layers/attention.py      | 21 ++++++++--------
 llmfoundry/models/layers/fc.py             |  8 +++---
 llmfoundry/models/layers/ffn.py            | 29 ++++++++++++----------
 llmfoundry/models/layers/layer_builders.py | 22 ++++++++++++++--
 llmfoundry/models/mpt/configuration_mpt.py |  3 +--
 llmfoundry/models/utils/param_init_fns.py  |  5 ++--
 8 files changed, 62 insertions(+), 36 deletions(-)

diff --git a/llmfoundry/layers_registry.py b/llmfoundry/layers_registry.py
index 9c7dabe128..6937994c0c 100644
--- a/llmfoundry/layers_registry.py
+++ b/llmfoundry/layers_registry.py
@@ -15,6 +15,14 @@
                         entry_points=True,
                         description=_norm_description)
 
+_fc_description = """The fully connected layers registry is used to register classes that implement fully connected layers."""
+fcs = create_registry('llmfoundry',
+                      'fcs',
+                      generic_type=Type[torch.nn.Module],
+                      entry_points=True,
+                      description=_fc_description)
+
 __all__ = [
     'norms',
+    'fcs',
 ]
diff --git a/llmfoundry/models/layers/__init__.py b/llmfoundry/models/layers/__init__.py
index 262f190b47..41dff4c770 100644
--- a/llmfoundry/models/layers/__init__.py
+++ b/llmfoundry/models/layers/__init__.py
@@ -7,7 +7,6 @@
                                                 flash_attn_fn,
                                                 scaled_multihead_dot_product_attention)
 from llmfoundry.models.layers.blocks import MPTBlock
 from llmfoundry.models.layers.custom_embedding import SharedEmbedding
-from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY
 from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY, MPTMLP, build_ffn
 from llmfoundry.models.layers.norm import LPLayerNorm
 
@@ -24,7 +23,6 @@
     'MPTMLP',
     'MPTBlock',
     'LPLayerNorm',
-    'FC_CLASS_REGISTRY',
     'SharedEmbedding',
     'FFN_CLASS_REGISTRY',
     'build_ffn',
diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py
index c24b3d4afa..331c037ce0 100644
--- a/llmfoundry/models/layers/attention.py
+++ b/llmfoundry/models/layers/attention.py
@@ -14,8 +14,7 @@
 from packaging import version
 from torch import nn
 
-from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY
-from llmfoundry.models.layers.layer_builders import build_norm
+from llmfoundry.models.layers.layer_builders import build_fc, build_norm
 
 
 def is_flash_v2_installed(v2_version: str = '2.0.0'):
@@ -406,10 +405,11 @@
             'bias': bias,
         }
         fc_kwargs['device'] = device
-        self.Wqkv = FC_CLASS_REGISTRY[fc_type](
-            self.d_model,
-            self.d_model + 2 * self.kv_n_heads * self.head_dim,
-            **fc_kwargs,
+        self.Wqkv = build_fc(
+            name=fc_type,
+            in_features=self.d_model,
+            out_features=self.d_model + 2 * self.kv_n_heads * self.head_dim,
+            fc_kwargs=fc_kwargs,
         )
         # for param init fn; enables shape based init of fused layers
         fuse_splits = [
@@ -440,10 +440,11 @@
         else:
             raise ValueError(f'{attn_impl=} is an invalid setting.')
 
-        self.out_proj = FC_CLASS_REGISTRY[fc_type](
-            self.d_model,
-            self.d_model,
-            **fc_kwargs,
+        self.out_proj = build_fc(
+            name=fc_type,
+            in_features=self.d_model,
+            out_features=self.d_model,
+            fc_kwargs=fc_kwargs,
         )
         self.out_proj._is_residual = True
 
diff --git a/llmfoundry/models/layers/fc.py b/llmfoundry/models/layers/fc.py
index b85bc133bd..8650e4966f 100644
--- a/llmfoundry/models/layers/fc.py
+++ b/llmfoundry/models/layers/fc.py
@@ -3,12 +3,12 @@
 
 from torch import nn
 
-FC_CLASS_REGISTRY = {
-    'torch': nn.Linear,
-}
+from llmfoundry.layers_registry import fcs
+
+fcs.register('torch', func=nn.Linear)
 
 try:
     import transformer_engine.pytorch as te
-    FC_CLASS_REGISTRY['te'] = te.Linear
+    fcs.register('te', func=te.Linear)
 except:
     pass
diff --git a/llmfoundry/models/layers/ffn.py b/llmfoundry/models/layers/ffn.py
index 9389cf385f..23ab58b2c6 100644
--- a/llmfoundry/models/layers/ffn.py
+++ b/llmfoundry/models/layers/ffn.py
@@ -11,7 +11,7 @@
 import torch
 import torch.nn as nn
 
-from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY
+from llmfoundry.models.layers.layer_builders import build_fc
 
 try:
     import transformer_engine.pytorch as te
@@ -100,16 +100,18 @@
 
         self.fc_kwargs['device'] = device
 
-        self.up_proj = FC_CLASS_REGISTRY[fc_type](
-            d_model,
-            ffn_hidden_size,
-            **self.fc_kwargs,
+        self.up_proj = build_fc(
+            name=fc_type,
+            in_features=d_model,
+            out_features=ffn_hidden_size,
+            fc_kwargs=self.fc_kwargs,
         )
         self.act = act_fn
-        self.down_proj = FC_CLASS_REGISTRY[fc_type](
-            ffn_hidden_size,
-            d_model,
-            **self.fc_kwargs,
+        self.down_proj = build_fc(
+            name=fc_type,
+            in_features=ffn_hidden_size,
+            out_features=d_model,
+            fc_kwargs=self.fc_kwargs,
         )
         self.down_proj._is_residual = True
 
@@ -138,10 +140,11 @@
             device=device,
             bias=bias,
         )
-        self.gate_proj = FC_CLASS_REGISTRY[fc_type](
-            d_model,
-            self.up_proj.out_features,
-            **self.fc_kwargs,
+        self.gate_proj = build_fc(
+            name=fc_type,
+            in_features=d_model,
+            out_features=self.up_proj.out_features,
+            fc_kwargs=self.fc_kwargs,
         )
 
     @torch.compile
diff --git a/llmfoundry/models/layers/layer_builders.py b/llmfoundry/models/layers/layer_builders.py
index 23f5b89668..8244089115 100644
--- a/llmfoundry/models/layers/layer_builders.py
+++ b/llmfoundry/models/layers/layer_builders.py
@@ -1,11 +1,11 @@
 # Copyright 2024 MosaicML LLM Foundry authors
 # SPDX-License-Identifier: Apache-2.0
 
-from typing import List, Optional, Union
+from typing import Any, Dict, List, Optional, Union
 
 import torch
 
-from llmfoundry.layers_registry import norms
+from llmfoundry.layers_registry import fcs, norms
 from llmfoundry.utils.registry_utils import construct_from_registry
 
 
@@ -23,3 +23,21 @@
                                    registry=norms,
                                    pre_validation_function=torch.nn.Module,
                                    kwargs=kwargs)
+
+
+def build_fc(
+    name: str,
+    in_features: int,
+    out_features: int,
+    fc_kwargs: Dict[str, Any],
+):
+    kwargs = {
+        'in_features': in_features,
+        'out_features': out_features,
+        **fc_kwargs,
+    }
+
+    return construct_from_registry(name=name,
+                                   registry=fcs,
+                                   pre_validation_function=torch.nn.Module,
+                                   kwargs=kwargs)
diff --git a/llmfoundry/models/mpt/configuration_mpt.py b/llmfoundry/models/mpt/configuration_mpt.py
index 2f58ea312e..f4780a68ce 100644
--- a/llmfoundry/models/mpt/configuration_mpt.py
+++ b/llmfoundry/models/mpt/configuration_mpt.py
@@ -16,10 +16,9 @@
 # HuggingFace can detect all the needed files to copy into its modules folder.
 # Otherwise, certain modules are missing.
 # isort: off
-from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY  # type: ignore (see note)
 from llmfoundry.models.layers.norm import LPLayerNorm  # type: ignore (see note)
 from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY  # type: ignore (see note)
-from llmfoundry.models.layers.layer_builders import build_norm  # type: ignore (see note)
+from llmfoundry.models.layers.layer_builders import build_norm, build_fc  # type: ignore (see note)
 from llmfoundry.layers_registry import norms  # type: ignore (see note)
 from llmfoundry.utils.registry_utils import construct_from_registry  # type: ignore (see note)
 
diff --git a/llmfoundry/models/utils/param_init_fns.py b/llmfoundry/models/utils/param_init_fns.py
index 35dc88a408..009a202072 100644
--- a/llmfoundry/models/utils/param_init_fns.py
+++ b/llmfoundry/models/utils/param_init_fns.py
@@ -10,8 +10,7 @@
 import torch
 from torch import nn
 
-from llmfoundry.layers_registry import norms
-from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY
+from llmfoundry.layers_registry import fcs, norms
 
 try:
     import transformer_engine.pytorch as te
@@ -87,7 +86,7 @@
         f'Expected init_div_is_residual to be boolean or numeric, got {init_div_is_residual}'
     )
 
-    if isinstance(module, tuple(set(FC_CLASS_REGISTRY.values()))):
+    if isinstance(module, tuple(set([fcs.get(n) for n in fcs.get_all()]))):
         # Linear
         if hasattr(module, '_fused'):
             fused_init_helper_(module, init_fn_)
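What PATCH 1 amounts to: the module-level FC_CLASS_REGISTRY dict becomes a typed fcs registry, and every call site trades positional construction (FC_CLASS_REGISTRY[fc_type](d_in, d_out, **fc_kwargs)) for the keyword-based build_fc helper. Below is a minimal sketch of how a custom fully connected layer could plug into the new path. The LowRankLinear class and its rank argument are hypothetical, invented here for illustration; the fcs.register and build_fc calls follow the signatures added in this patch.

    from typing import Optional

    import torch
    import torch.nn as nn

    from llmfoundry.layers_registry import fcs
    from llmfoundry.models.layers.layer_builders import build_fc


    class LowRankLinear(nn.Module):
        """Toy nn.Linear stand-in that factors the weight through a bottleneck."""

        def __init__(self,
                     in_features: int,
                     out_features: int,
                     rank: int = 8,
                     bias: bool = True,
                     device: Optional[str] = None):
            super().__init__()
            # Expose out_features the way nn.Linear does; MPTGLU reads
            # self.up_proj.out_features when building its gate_proj.
            self.in_features = in_features
            self.out_features = out_features
            self.down = nn.Linear(in_features, rank, bias=False, device=device)
            self.up = nn.Linear(rank, out_features, bias=bias, device=device)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.up(self.down(x))


    fcs.register('low_rank', func=LowRankLinear)

    # build_fc merges in_features/out_features with fc_kwargs before
    # constructing from the registry, so extra args like rank ride along.
    proj = build_fc(name='low_rank',
                    in_features=512,
                    out_features=512,
                    fc_kwargs={'rank': 16})

Because build_fc folds in_features and out_features into one kwargs dict, any extra constructor arguments can simply travel in fc_kwargs, which is how the attention.py and ffn.py call sites above pass bias and device.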
From ca51e12fd6bae54cd2fc6456f96af58bc4b76eec Mon Sep 17 00:00:00 2001
From: Daniel King
Date: Wed, 3 Apr 2024 21:29:15 -0700
Subject: [PATCH 2/4] fix

---
 llmfoundry/models/layers/__init__.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/llmfoundry/models/layers/__init__.py b/llmfoundry/models/layers/__init__.py
index 41dff4c770..7328a00757 100644
--- a/llmfoundry/models/layers/__init__.py
+++ b/llmfoundry/models/layers/__init__.py
@@ -7,6 +7,7 @@
                                                 flash_attn_fn,
                                                 scaled_multihead_dot_product_attention)
 from llmfoundry.models.layers.blocks import MPTBlock
 from llmfoundry.models.layers.custom_embedding import SharedEmbedding
+from llmfoundry.models.layers.fc import *
 from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY, MPTMLP, build_ffn
 from llmfoundry.models.layers.norm import LPLayerNorm

From 6016dac3c44899f75cc39b3b56835028ad145174 Mon Sep 17 00:00:00 2001
From: Daniel King
Date: Thu, 4 Apr 2024 18:53:25 -0700
Subject: [PATCH 3/4] clean up

---
 llmfoundry/layers_registry.py | 7 +++++--
 llmfoundry/registry.py        | 3 ++-
 tests/test_registry.py       | 1 +
 3 files changed, 8 insertions(+), 3 deletions(-)

diff --git a/llmfoundry/layers_registry.py b/llmfoundry/layers_registry.py
index 6937994c0c..9686e76af7 100644
--- a/llmfoundry/layers_registry.py
+++ b/llmfoundry/layers_registry.py
@@ -14,8 +14,11 @@
                         generic_type=Type[torch.nn.Module],
                         entry_points=True,
                         description=_norm_description)
-
-_fc_description = """The fully connected layers registry is used to register classes that implement fully connected layers."""
+_fc_description = (
+    'The fully connected layers registry is used to register classes that implement fully connected layers (i.e. torch.nn.Linear).'
+    +
+    'These classes should take in_features and out_features in as args, at a minimum.'
+)
 fcs = create_registry('llmfoundry',
                       'fcs',
                       generic_type=Type[torch.nn.Module],
                       entry_points=True,
                       description=_fc_description)
diff --git a/llmfoundry/registry.py b/llmfoundry/registry.py
index 424075da3b..7b605808be 100644
--- a/llmfoundry/registry.py
+++ b/llmfoundry/registry.py
@@ -12,7 +12,7 @@
 from transformers import PreTrainedTokenizerBase
 
 from llmfoundry.interfaces import CallbackWithConfig
-from llmfoundry.layers_registry import norms
+from llmfoundry.layers_registry import fcs, norms
 from llmfoundry.utils.registry_utils import create_registry
 
 _loggers_description = (
@@ -121,4 +121,5 @@
     'metrics',
     'dataloaders',
     'norms',
+    'fcs',
 ]
diff --git a/tests/test_registry.py b/tests/test_registry.py
index c93c7c9749..6fa0f16b49 100644
--- a/tests/test_registry.py
+++ b/tests/test_registry.py
@@ -31,6 +31,7 @@
         'metrics',
         'models',
         'norms',
+        'fcs',
     }
 
     assert existing_registries == expected_registry_names
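PATCH 2 restores registration after PATCH 1 dropped the fc import from the package __init__ (without an import of llmfoundry.models.layers.fc, its module-level fcs.register calls never run), and PATCH 3 tightens the registry description and exposes fcs through llmfoundry.registry and the registry test. A quick sanity check in the spirit of tests/test_registry.py; the fcs.get and fcs.get_all calls mirror the usage PATCH 1 adds to param_init_fns.py, and any registry behavior beyond those two calls is an assumption rather than something this series shows.

    from torch import nn

    from llmfoundry.layers_registry import fcs

    # 'torch' is registered unconditionally in fc.py; 'te' only appears
    # when transformer_engine imports successfully.
    assert 'torch' in fcs.get_all()
    assert fcs.get('torch') is nn.Linear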
From 135b341a18292855779ee11df0b46b103d645757 Mon Sep 17 00:00:00 2001
From: Daniel King
Date: Thu, 11 Apr 2024 15:50:04 -0700
Subject: [PATCH 4/4] pc

---
 llmfoundry/models/layers/ffn.py           | 2 +-
 llmfoundry/models/utils/param_init_fns.py | 1 -
 2 files changed, 1 insertion(+), 2 deletions(-)

diff --git a/llmfoundry/models/layers/ffn.py b/llmfoundry/models/layers/ffn.py
index e8601f5212..f0b499875a 100644
--- a/llmfoundry/models/layers/ffn.py
+++ b/llmfoundry/models/layers/ffn.py
@@ -12,8 +12,8 @@
 import torch.nn as nn
 from torch.distributed._tensor import DeviceMesh, DTensor, Placement, Shard
 
-from llmfoundry.models.layers.layer_builders import build_fc
 from llmfoundry.models.layers.dmoe import dMoE
+from llmfoundry.models.layers.layer_builders import build_fc
 
 try:
     import transformer_engine.pytorch as te
diff --git a/llmfoundry/models/utils/param_init_fns.py b/llmfoundry/models/utils/param_init_fns.py
index 807445f225..bd409dee36 100644
--- a/llmfoundry/models/utils/param_init_fns.py
+++ b/llmfoundry/models/utils/param_init_fns.py
@@ -13,7 +13,6 @@
 from torch.distributed._tensor import DTensor
 
 from llmfoundry.layers_registry import fcs, norms
-from llmfoundry.layers_registry import norms
 from llmfoundry.models.layers.dmoe import GLU, MLP
 
 try:
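PATCH 4 is pre-commit cleanup: isort ordering in ffn.py and a duplicate import left behind by a merge in param_init_fns.py. The surviving context line above (from llmfoundry.layers_registry import fcs, norms) feeds the one behavioral change of the series worth calling out: linear-like modules are now detected by enumerating the registry instead of reading FC_CLASS_REGISTRY.values(). A standalone sketch of that check, assuming fcs.get and fcs.get_all behave as used in PATCH 1:

    import torch.nn as nn

    from llmfoundry.layers_registry import fcs


    def is_registered_fc(module: nn.Module) -> bool:
        # Resolve every registered name to its class and dedupe; this is
        # the same pattern generic_param_init_fn_ uses after this series.
        fc_classes = tuple(set(fcs.get(n) for n in fcs.get_all()))
        return isinstance(module, fc_classes)


    assert is_registered_fc(nn.Linear(4, 4))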