From b81897ac0a3ee9eb58847e0a8645a81ce11c280a Mon Sep 17 00:00:00 2001 From: Daniel King <43149077+dakinggg@users.noreply.github.com> Date: Thu, 4 Apr 2024 18:18:59 -0700 Subject: [PATCH] Norms registry (#1080) --- README.md | 2 +- llmfoundry/layers_registry.py | 20 ++++++++++++++++ llmfoundry/models/layers/__init__.py | 3 +-- llmfoundry/models/layers/attention.py | 15 ++++++++---- llmfoundry/models/layers/blocks.py | 15 ++++++++---- llmfoundry/models/layers/layer_builders.py | 25 ++++++++++++++++++++ llmfoundry/models/layers/norm.py | 19 +++++++-------- llmfoundry/models/mpt/configuration_mpt.py | 3 +++ llmfoundry/models/mpt/modeling_mpt.py | 14 +++++++---- llmfoundry/models/utils/act_ckpt.py | 8 +++---- llmfoundry/models/utils/param_init_fns.py | 5 ++-- llmfoundry/registry.py | 2 ++ llmfoundry/utils/registry_utils.py | 7 ++++++ tests/models/test_model.py | 5 ++-- tests/models/test_rmsnorm_triton_vs_eager.py | 20 +++++++++------- tests/test_registry.py | 1 + 16 files changed, 121 insertions(+), 43 deletions(-) create mode 100644 llmfoundry/layers_registry.py create mode 100644 llmfoundry/models/layers/layer_builders.py diff --git a/README.md b/README.md index 5ed5bd3ee9..ef2a754658 100644 --- a/README.md +++ b/README.md @@ -306,7 +306,7 @@ dependencies = [ "llm-foundry", ] -[project.entry-points."llm_foundry.loggers"] +[project.entry-points."llmfoundry_loggers"] my_logger = "foundry_registry.loggers:MyLogger" ``` diff --git a/llmfoundry/layers_registry.py b/llmfoundry/layers_registry.py new file mode 100644 index 0000000000..9c7dabe128 --- /dev/null +++ b/llmfoundry/layers_registry.py @@ -0,0 +1,20 @@ +# Copyright 2024 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +from typing import Type + +import torch + +from llmfoundry.utils.registry_utils import create_registry + +# Layers +_norm_description = """The norms registry is used to register classes that implement normalization layers.""" +norms = create_registry('llmfoundry', + 'norms', + generic_type=Type[torch.nn.Module], + entry_points=True, + description=_norm_description) + +__all__ = [ + 'norms', +] diff --git a/llmfoundry/models/layers/__init__.py b/llmfoundry/models/layers/__init__.py index df4216b81c..262f190b47 100644 --- a/llmfoundry/models/layers/__init__.py +++ b/llmfoundry/models/layers/__init__.py @@ -9,7 +9,7 @@ from llmfoundry.models.layers.custom_embedding import SharedEmbedding from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY, MPTMLP, build_ffn -from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY, LPLayerNorm +from llmfoundry.models.layers.norm import LPLayerNorm __all__ = [ 'scaled_multihead_dot_product_attention', @@ -23,7 +23,6 @@ 'ATTN_CLASS_REGISTRY', 'MPTMLP', 'MPTBlock', - 'NORM_CLASS_REGISTRY', 'LPLayerNorm', 'FC_CLASS_REGISTRY', 'SharedEmbedding', diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py index e3ba488f3f..c24b3d4afa 100644 --- a/llmfoundry/models/layers/attention.py +++ b/llmfoundry/models/layers/attention.py @@ -15,7 +15,7 @@ from torch import nn from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY -from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY +from llmfoundry.models.layers.layer_builders import build_norm def is_flash_v2_installed(v2_version: str = '2.0.0'): @@ -419,12 +419,19 @@ def __init__( self.Wqkv._fused = (0, fuse_splits) if self.qk_ln or self.qk_gn: - norm_class = NORM_CLASS_REGISTRY[norm_type.lower()] norm_size = self.head_dim if 
qk_gn else d_model - self.q_ln = norm_class(norm_size, device=device) + self.q_ln = build_norm( + name=norm_type.lower(), + normalized_shape=norm_size, + device=device, + ) if qk_ln: norm_size = self.head_dim * kv_n_heads - self.k_ln = norm_class(norm_size, device=device) + self.k_ln = build_norm( + name=norm_type.lower(), + normalized_shape=norm_size, + device=device, + ) if self.attn_impl == 'flash': self.attn_fn = flash_attn_fn diff --git a/llmfoundry/models/layers/blocks.py b/llmfoundry/models/layers/blocks.py index 855df7903f..42feb983d4 100644 --- a/llmfoundry/models/layers/blocks.py +++ b/llmfoundry/models/layers/blocks.py @@ -10,7 +10,7 @@ from llmfoundry.models.layers.attention import ATTN_CLASS_REGISTRY from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY, build_ffn -from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY +from llmfoundry.models.layers.layer_builders import build_norm try: from flash_attn.bert_padding import unpad_input, pad_input # type: ignore # yapf: disable # isort: skip @@ -72,7 +72,6 @@ def __init__( del kwargs # unused, just to capture any extra args from the config super().__init__() - norm_class = NORM_CLASS_REGISTRY[norm_type.lower()] assert isinstance(attn_config['attn_type'], str) attn_class = ATTN_CLASS_REGISTRY[attn_config['attn_type']] @@ -88,7 +87,11 @@ def __init__( if k not in args_to_exclude_in_attn_class } - self.norm_1 = norm_class(d_model, device=device) + self.norm_1 = build_norm( + name=norm_type.lower(), + normalized_shape=d_model, + device=device, + ) self.attn = attn_class( d_model=d_model, n_heads=n_heads, @@ -100,7 +103,11 @@ def __init__( self.norm_2 = None if not getattr(FFN_CLASS_REGISTRY[ffn_config['ffn_type']], '_has_norm', False): - self.norm_2 = norm_class(d_model, device=device) + self.norm_2 = build_norm( + name=norm_type.lower(), + normalized_shape=d_model, + device=device, + ) self.ffn = build_ffn( d_model=d_model, expansion_ratio=expansion_ratio, diff --git a/llmfoundry/models/layers/layer_builders.py b/llmfoundry/models/layers/layer_builders.py new file mode 100644 index 0000000000..23f5b89668 --- /dev/null +++ b/llmfoundry/models/layers/layer_builders.py @@ -0,0 +1,25 @@ +# Copyright 2024 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +from typing import List, Optional, Union + +import torch + +from llmfoundry.layers_registry import norms +from llmfoundry.utils.registry_utils import construct_from_registry + + +def build_norm( + name: str, + normalized_shape: Union[int, List[int], torch.Size], + device: Optional[str] = None, +): + kwargs = { + 'normalized_shape': normalized_shape, + 'device': device, + } + + return construct_from_registry(name=name, + registry=norms, + pre_validation_function=torch.nn.Module, + kwargs=kwargs) diff --git a/llmfoundry/models/layers/norm.py b/llmfoundry/models/layers/norm.py index be4f50f521..92d295c71c 100644 --- a/llmfoundry/models/layers/norm.py +++ b/llmfoundry/models/layers/norm.py @@ -1,10 +1,14 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 -from typing import Dict, List, Optional, Type, Union +from typing import List, Optional, Union import torch +from llmfoundry.layers_registry import norms + +norms.register(name='layernorm', func=torch.nn.LayerNorm) + def _cast_if_autocast_enabled(tensor: torch.Tensor) -> torch.Tensor: if torch.is_autocast_enabled(): @@ -18,6 +22,7 @@ def _cast_if_autocast_enabled(tensor: torch.Tensor) -> torch.Tensor: return tensor +@norms.register_class('low_precision_layernorm') class 
LPLayerNorm(torch.nn.LayerNorm): def __init__( @@ -62,6 +67,7 @@ def rms_norm(x: torch.Tensor, return output +@norms.register_class('rmsnorm') class RMSNorm(torch.nn.Module): def __init__( @@ -84,6 +90,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return rms_norm(x.float(), self.weight, self.eps).to(dtype=x.dtype) +@norms.register_class('low_precision_rmsnorm') class LPRMSNorm(RMSNorm): def __init__( @@ -111,6 +118,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: self.eps).to(dtype=x.dtype) +@norms.register_class('triton_rmsnorm') class TritonRMSNorm(torch.nn.Module): def __init__( @@ -150,12 +158,3 @@ def forward(self, x: torch.Tensor): prenorm=False, residual_in_fp32=False, ) - - -NORM_CLASS_REGISTRY: Dict[str, Type[torch.nn.Module]] = { - 'layernorm': torch.nn.LayerNorm, - 'low_precision_layernorm': LPLayerNorm, - 'rmsnorm': RMSNorm, - 'low_precision_rmsnorm': LPRMSNorm, - 'triton_rmsnorm': TritonRMSNorm, -} diff --git a/llmfoundry/models/mpt/configuration_mpt.py b/llmfoundry/models/mpt/configuration_mpt.py index 20c3850a82..2f58ea312e 100644 --- a/llmfoundry/models/mpt/configuration_mpt.py +++ b/llmfoundry/models/mpt/configuration_mpt.py @@ -19,6 +19,9 @@ from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY # type: ignore (see note) from llmfoundry.models.layers.norm import LPLayerNorm # type: ignore (see note) from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY # type: ignore (see note) +from llmfoundry.models.layers.layer_builders import build_norm # type: ignore (see note) +from llmfoundry.layers_registry import norms # type: ignore (see note) +from llmfoundry.utils.registry_utils import construct_from_registry # type: ignore (see note) ffn_config_defaults: Dict = { 'ffn_type': 'mptmlp', diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 183e1b24f6..d54b797269 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -20,7 +20,6 @@ from composer.utils import dist from llmfoundry.models.layers.attention import is_flash_v2_installed -from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY if is_flash_v2_installed(): try: # This try...except is needed because transformers requires it despite the 'if' statement above @@ -42,11 +41,13 @@ from transformers.models.llama.modeling_llama import \ LlamaRotaryEmbedding as HFRotaryEmbedding +from llmfoundry.layers_registry import norms from llmfoundry.models.layers.attention import (attn_bias_shape, build_attn_bias, gen_slopes) from llmfoundry.models.layers.blocks import MPTBlock from llmfoundry.models.layers.custom_embedding import SharedEmbedding from llmfoundry.models.layers.ffn import build_ffn as build_ffn +from llmfoundry.models.layers.layer_builders import build_norm from llmfoundry.models.mpt.configuration_mpt import MPTConfig # NOTE: All utils are imported directly even if unused so that @@ -297,12 +298,11 @@ def __init__(self, config: MPTConfig): else: config.init_device = 'meta' - if config.norm_type.lower() not in NORM_CLASS_REGISTRY.keys(): - norm_options = ' | '.join(NORM_CLASS_REGISTRY.keys()) + if config.norm_type.lower() not in norms.get_all(): + norm_options = ' | '.join(norms.get_all()) raise NotImplementedError( f'Requested norm type ({config.norm_type}) is not implemented within this repo (Options: {norm_options}).' 
) - norm_class = NORM_CLASS_REGISTRY[config.norm_type.lower()] # CogView (https://arxiv.org/abs/2105.13290) and GLM-130B (https://arxiv.org/abs/2210.02414) # both report this helping with stabilizing training @@ -329,7 +329,11 @@ def __init__(self, config: MPTConfig): block.max_block_idx = config.n_layers - 1 pass_on_block_idx(block) - self.norm_f = norm_class(config.d_model, device=config.init_device) + self.norm_f = build_norm( + name=config.norm_type.lower(), + normalized_shape=config.d_model, + device=config.init_device, + ) self.rope = config.attn_config['rope'] self.rope_impl = None diff --git a/llmfoundry/models/utils/act_ckpt.py b/llmfoundry/models/utils/act_ckpt.py index 9acd7dd11c..bde7c92bd7 100644 --- a/llmfoundry/models/utils/act_ckpt.py +++ b/llmfoundry/models/utils/act_ckpt.py @@ -5,10 +5,10 @@ import torch +from llmfoundry.layers_registry import norms from llmfoundry.models.layers.attention import ATTN_CLASS_REGISTRY from llmfoundry.models.layers.blocks import MPTBlock from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY -from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY def pass_on_block_idx(parent: torch.nn.Module): @@ -29,12 +29,12 @@ def get_act_ckpt_module(mod_name: str) -> Any: mod_type = ATTN_CLASS_REGISTRY[mod_name] elif mod_name in FFN_CLASS_REGISTRY: mod_type = FFN_CLASS_REGISTRY[mod_name] - elif mod_name in NORM_CLASS_REGISTRY: - mod_type = NORM_CLASS_REGISTRY[mod_name] + elif mod_name in norms: + mod_type = norms.get(mod_name) else: msg = ', '.join( list(ATTN_CLASS_REGISTRY.keys()) + list(FFN_CLASS_REGISTRY.keys()) + - list(NORM_CLASS_REGISTRY.keys()) + ['MPTBlock']) + list(norms.get_all()) + ['MPTBlock']) raise ValueError( f'{mod_name} (specified in activation_checkpointing_target) is not a recognized option out of available options {msg}.' 
) diff --git a/llmfoundry/models/utils/param_init_fns.py b/llmfoundry/models/utils/param_init_fns.py index 2e72ccfa47..35dc88a408 100644 --- a/llmfoundry/models/utils/param_init_fns.py +++ b/llmfoundry/models/utils/param_init_fns.py @@ -10,8 +10,8 @@ import torch from torch import nn +from llmfoundry.layers_registry import norms from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY -from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY try: import transformer_engine.pytorch as te @@ -129,7 +129,8 @@ def generic_param_init_fn_( emb_init_fn_(module.weight) - elif isinstance(module, tuple(set(NORM_CLASS_REGISTRY.values()))): + elif isinstance(module, + tuple(set([norms.get(name) for name in norms.get_all()]))): # Norm if hasattr(module, 'weight') and isinstance(module.weight, torch.Tensor): diff --git a/llmfoundry/registry.py b/llmfoundry/registry.py index e289a923b6..424075da3b 100644 --- a/llmfoundry/registry.py +++ b/llmfoundry/registry.py @@ -12,6 +12,7 @@ from transformers import PreTrainedTokenizerBase from llmfoundry.interfaces import CallbackWithConfig +from llmfoundry.layers_registry import norms from llmfoundry.utils.registry_utils import create_registry _loggers_description = ( @@ -119,4 +120,5 @@ 'models', 'metrics', 'dataloaders', + 'norms', ] diff --git a/llmfoundry/utils/registry_utils.py b/llmfoundry/utils/registry_utils.py index 7089996a13..0901ea198a 100644 --- a/llmfoundry/utils/registry_utils.py +++ b/llmfoundry/utils/registry_utils.py @@ -14,6 +14,7 @@ __all__ = ['TypedRegistry', 'create_registry', 'construct_from_registry'] T = TypeVar('T') +TypeBoundT = TypeVar('TypeBoundT', bound=Type) class TypedRegistry(catalogue.Registry, Generic[T]): @@ -36,6 +37,12 @@ def __call__(self, name: str, func: Optional[T] = None) -> Callable[[T], T]: def register(self, name: str, *, func: Optional[T] = None) -> T: return super().register(name, func=func) + def register_class(self, + name: str, + *, + func: Optional[TypeBoundT] = None) -> TypeBoundT: + return super().register(name, func=func) + def get(self, name: str) -> T: return super().get(name) diff --git a/tests/models/test_model.py b/tests/models/test_model.py index c5f6062b0e..7bd8292151 100644 --- a/tests/models/test_model.py +++ b/tests/models/test_model.py @@ -26,8 +26,9 @@ from transformers.models.bloom.modeling_bloom import build_alibi_tensor from llmfoundry import ComposerHFCausalLM +from llmfoundry.layers_registry import norms from llmfoundry.models.hf.model_wrapper import HuggingFaceModelWithFSDP -from llmfoundry.models.layers import NORM_CLASS_REGISTRY, build_alibi_bias +from llmfoundry.models.layers import build_alibi_bias from llmfoundry.models.layers.attention import (check_alibi_support, is_flash_v2_installed) from llmfoundry.models.layers.blocks import MPTBlock @@ -682,7 +683,7 @@ def test_lora_id(): assert isinstance(model.model, peft.PeftModelForCausalLM) -@pytest.mark.parametrize('norm_type', NORM_CLASS_REGISTRY.keys()) +@pytest.mark.parametrize('norm_type', norms.get_all()) @pytest.mark.parametrize('no_bias', [False, True]) @pytest.mark.parametrize('tie_word_embeddings', [True, False]) @pytest.mark.parametrize('expansion_ratio,ffn_hidden_size', [ diff --git a/tests/models/test_rmsnorm_triton_vs_eager.py b/tests/models/test_rmsnorm_triton_vs_eager.py index 1902f46d78..7169c5d926 100644 --- a/tests/models/test_rmsnorm_triton_vs_eager.py +++ b/tests/models/test_rmsnorm_triton_vs_eager.py @@ -8,6 +8,7 @@ from composer.core.precision import get_precision_context from llmfoundry.models.layers.attention 
import is_flash_v2_installed +from llmfoundry.models.layers.layer_builders import build_norm @pytest.mark.gpu @@ -19,17 +20,18 @@ def test_rmsnorm_triton_vs_eager(normalized_shape: Union[int, List[int]], pytest.skip( 'triton implementation of rmsnorm requires flash attention 2.') - from llmfoundry.models.layers import norm - batch_size = 2 - cfg = { - 'normalized_shape': normalized_shape, - 'device': device, - } - - eager_rmsnorm = norm.NORM_CLASS_REGISTRY['rmsnorm'](**cfg) - triton_rmsnorm = norm.NORM_CLASS_REGISTRY['triton_rmsnorm'](**cfg) + eager_rmsnorm = build_norm( + name='rmsnorm', + normalized_shape=normalized_shape, + device=device, + ) + triton_rmsnorm = build_norm( + name='triton_rmsnorm', + normalized_shape=normalized_shape, + device=device, + ) triton_rmsnorm.load_state_dict(eager_rmsnorm.state_dict()) diff --git a/tests/test_registry.py b/tests/test_registry.py index 30f6e0e38f..c93c7c9749 100644 --- a/tests/test_registry.py +++ b/tests/test_registry.py @@ -30,6 +30,7 @@ def test_expected_registries_exist(): 'dataloaders', 'metrics', 'models', + 'norms', } assert existing_registries == expected_registry_names
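
A minimal sketch of how downstream code could exercise the registry this patch introduces. The `norms` registry, the `register_class` decorator, and the `build_norm(name, normalized_shape, device)` signature are taken from the diff above; the `unit_norm` name and the `MyUnitNorm` class are hypothetical illustrations, not part of the patch.

```python
# Hypothetical example: register a custom normalization layer with the new
# `norms` registry and construct it through `build_norm`, the same path
# MPTBlock now uses for norm_1 / norm_2.
import torch

from llmfoundry.layers_registry import norms
from llmfoundry.models.layers.layer_builders import build_norm


@norms.register_class('unit_norm')  # decorator form, as used for the built-in norms
class MyUnitNorm(torch.nn.Module):
    """Scales activations to unit L2 norm over the last dimension (illustrative only)."""

    def __init__(self, normalized_shape, device=None):
        super().__init__()
        # `normalized_shape` and `device` match the kwargs build_norm passes through;
        # this toy layer has no parameters, so `device` is accepted but unused.
        self.normalized_shape = normalized_shape

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x / x.norm(dim=-1, keepdim=True).clamp_min(1e-6)


# build_norm looks the class up by name in the registry and instantiates it
# with the normalized_shape and device kwargs.
norm_layer = build_norm(name='unit_norm', normalized_shape=512, device='cpu')
print(type(norm_layer).__name__)  # MyUnitNorm
```

Because the registry is created with `entry_points=True`, a separate package should also be able to contribute a norm without modifying this repo, presumably via an entry-point group mirroring the `llmfoundry_loggers` group shown in the README hunk.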