From 560012b6282bdc44ca35e33d9c2439f71fbddfee Mon Sep 17 00:00:00 2001 From: Daniel King <43149077+dakinggg@users.noreply.github.com> Date: Thu, 11 Apr 2024 18:54:49 -0700 Subject: [PATCH] Attention layer registry (#1094) --- llmfoundry/layers_registry.py | 28 +++++++++++++-- llmfoundry/models/layers/__init__.py | 7 ++-- llmfoundry/models/layers/attention.py | 20 +++++------ llmfoundry/models/layers/blocks.py | 41 ++++++++++++---------- llmfoundry/models/layers/layer_builders.py | 12 ++++++- llmfoundry/models/utils/act_ckpt.py | 12 +++---- llmfoundry/registry.py | 26 ++++++++++---- llmfoundry/utils/registry_utils.py | 1 + tests/models/layers/test_flash_torch.py | 11 ++++-- tests/models/test_rope_dail_vs_hf.py | 15 +++++--- tests/test_registry.py | 2 ++ 11 files changed, 118 insertions(+), 57 deletions(-) diff --git a/llmfoundry/layers_registry.py b/llmfoundry/layers_registry.py index 9686e76af7..3c0a7ebd6e 100644 --- a/llmfoundry/layers_registry.py +++ b/llmfoundry/layers_registry.py @@ -1,14 +1,15 @@ # Copyright 2024 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 -from typing import Type +from typing import Callable, Type import torch from llmfoundry.utils.registry_utils import create_registry -# Layers -_norm_description = """The norms registry is used to register classes that implement normalization layers.""" +_norm_description = ( + 'The norms registry is used to register classes that implement normalization layers.' +) norms = create_registry('llmfoundry', 'norms', generic_type=Type[torch.nn.Module], @@ -25,7 +26,28 @@ entry_points=True, description=_fc_description) +_attention_classes_description = ( + 'The attention_classes registry is used to register classes that implement attention layers. See ' + + 'attention.py for expected constructor signature.') +attention_classes = create_registry('llmfoundry', + 'attention_classes', + generic_type=Type[torch.nn.Module], + entry_points=True, + description=_attention_classes_description) + +_attention_implementations_description = ( + 'The attention_implementations registry is used to register functions that implement the attention operation.' 
+ + 'See attention.py for expected function signature.') +attention_implementations = create_registry( + 'llmfoundry', + 'attention_implementations', + generic_type=Callable, + entry_points=True, + description=_attention_implementations_description) + __all__ = [ 'norms', + 'attention_classes', + 'attention_implementations', 'fcs', ] diff --git a/llmfoundry/models/layers/__init__.py b/llmfoundry/models/layers/__init__.py index 7328a00757..5784fcd7e9 100644 --- a/llmfoundry/models/layers/__init__.py +++ b/llmfoundry/models/layers/__init__.py @@ -2,9 +2,9 @@ # SPDX-License-Identifier: Apache-2.0 from llmfoundry.models.layers.attention import ( - ATTN_CLASS_REGISTRY, GroupedQueryAttention, MultiheadAttention, - MultiQueryAttention, attn_bias_shape, build_alibi_bias, build_attn_bias, - flash_attn_fn, scaled_multihead_dot_product_attention) + GroupedQueryAttention, MultiheadAttention, MultiQueryAttention, + attn_bias_shape, build_alibi_bias, build_attn_bias, flash_attn_fn, + scaled_multihead_dot_product_attention) from llmfoundry.models.layers.blocks import MPTBlock from llmfoundry.models.layers.custom_embedding import SharedEmbedding from llmfoundry.models.layers.fc import * @@ -20,7 +20,6 @@ 'attn_bias_shape', 'build_attn_bias', 'build_alibi_bias', - 'ATTN_CLASS_REGISTRY', 'MPTMLP', 'MPTBlock', 'LPLayerNorm', diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py index 331c037ce0..6614d5d161 100644 --- a/llmfoundry/models/layers/attention.py +++ b/llmfoundry/models/layers/attention.py @@ -14,6 +14,8 @@ from packaging import version from torch import nn +from llmfoundry.layers_registry import (attention_classes, + attention_implementations) from llmfoundry.models.layers.layer_builders import build_fc, build_norm @@ -340,6 +342,7 @@ def flash_attn_fn( return output, None, past_key_value +@attention_classes.register_class('grouped_query_attention') class GroupedQueryAttention(nn.Module): """Grouped Query Attention (GQA) is a generalization of Multi-head (MHA). @@ -433,12 +436,7 @@ def __init__( device=device, ) - if self.attn_impl == 'flash': - self.attn_fn = flash_attn_fn - elif self.attn_impl == 'torch': - self.attn_fn = scaled_multihead_dot_product_attention - else: - raise ValueError(f'{attn_impl=} is an invalid setting.') + self.attn_fn = attention_implementations.get(self.attn_impl) self.out_proj = build_fc( name=fc_type, @@ -573,6 +571,7 @@ def forward( return self.out_proj(context), attn_weights, past_key_value +@attention_classes.register_class('multihead_attention') class MultiheadAttention(GroupedQueryAttention): """Multi-head self attention. @@ -613,6 +612,7 @@ def __init__( ) +@attention_classes.register_class('multiquery_attention') class MultiQueryAttention(GroupedQueryAttention): """Multi-Query self attention. 
@@ -741,8 +741,6 @@ def build_alibi_bias( return alibi_bias.to(dtype=dtype) -ATTN_CLASS_REGISTRY = { - 'multihead_attention': MultiheadAttention, - 'multiquery_attention': MultiQueryAttention, - 'grouped_query_attention': GroupedQueryAttention -} +attention_implementations.register('flash', func=flash_attn_fn) +attention_implementations.register('torch', + func=scaled_multihead_dot_product_attention) diff --git a/llmfoundry/models/layers/blocks.py b/llmfoundry/models/layers/blocks.py index 18b9f979f4..1ad9ec954f 100644 --- a/llmfoundry/models/layers/blocks.py +++ b/llmfoundry/models/layers/blocks.py @@ -8,9 +8,9 @@ import torch import torch.nn as nn -from llmfoundry.models.layers.attention import ATTN_CLASS_REGISTRY from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY, build_ffn -from llmfoundry.models.layers.layer_builders import build_norm +from llmfoundry.models.layers.layer_builders import (build_attention_layer, + build_norm) try: from flash_attn.bert_padding import unpad_input, pad_input # type: ignore # yapf: disable # isort: skip @@ -87,8 +87,6 @@ def __init__( ) else: assert isinstance(attn_config['attn_type'], str) - attn_class = ATTN_CLASS_REGISTRY[attn_config['attn_type']] - # Necessary to avoid passing extraneous args into attn_class while allowing the use of **kwargs args_to_exclude_in_attn_class = { 'attn_type', 'alibi', 'attn_uses_sequence_id', 'alibi_bias_max', @@ -106,13 +104,16 @@ def __init__( normalized_shape=d_model, device=device, ) - self.attn = attn_class( - d_model=d_model, - n_heads=n_heads, - fc_type=fc_type, - device=device, - **attn_config_subset_for_attn_class, - bias=not no_bias, + self.attn = build_attention_layer( + name=attn_config['attn_type'], + attn_kwargs={ + 'd_model': d_model, + 'n_heads': n_heads, + 'fc_type': fc_type, + 'device': device, + 'bias': not no_bias, + **attn_config_subset_for_attn_class + }, ) self.norm_2 = None if not getattr(FFN_CLASS_REGISTRY[ffn_config['ffn_type']], @@ -209,7 +210,6 @@ def __init__( assert attn_config is not None assert ffn_config is not None assert isinstance(attn_config['attn_type'], str) - attn_class = ATTN_CLASS_REGISTRY[attn_config['attn_type']] # necessary to avoid passing extraneous args into attn_class while allowing the use of **kwargs args_to_exclude_in_attn_class = { @@ -227,13 +227,16 @@ def __init__( normalized_shape=d_model, device=device, ) - self.attn = attn_class( - d_model=d_model, - n_heads=n_heads, - fc_type=fc_type, - device=device, - **attn_config_subset_for_attn_class, - bias=not no_bias, + self.attn = build_attention_layer( + name=attn_config['attn_type'], + attn_kwargs={ + 'd_model': d_model, + 'n_heads': n_heads, + 'fc_type': fc_type, + 'device': device, + 'bias': not no_bias, + **attn_config_subset_for_attn_class + }, ) self.norm_2 = None if not getattr(FFN_CLASS_REGISTRY[ffn_config['ffn_type']], '_has_norm', diff --git a/llmfoundry/models/layers/layer_builders.py b/llmfoundry/models/layers/layer_builders.py index 8244089115..6a725d469a 100644 --- a/llmfoundry/models/layers/layer_builders.py +++ b/llmfoundry/models/layers/layer_builders.py @@ -5,7 +5,7 @@ import torch -from llmfoundry.layers_registry import fcs, norms +from llmfoundry.layers_registry import attention_classes, fcs, norms from llmfoundry.utils.registry_utils import construct_from_registry @@ -25,6 +25,16 @@ def build_norm( kwargs=kwargs) +def build_attention_layer( + name: str, + attn_kwargs: Dict[str, Any], +): + return construct_from_registry(name=name, + registry=attention_classes, + 
pre_validation_function=torch.nn.Module, + kwargs=attn_kwargs) + + def build_fc( name: str, in_features: int, diff --git a/llmfoundry/models/utils/act_ckpt.py b/llmfoundry/models/utils/act_ckpt.py index 1975865f1b..fea68492c1 100644 --- a/llmfoundry/models/utils/act_ckpt.py +++ b/llmfoundry/models/utils/act_ckpt.py @@ -5,8 +5,7 @@ import torch -from llmfoundry.layers_registry import norms -from llmfoundry.models.layers.attention import ATTN_CLASS_REGISTRY +from llmfoundry.layers_registry import attention_classes, norms from llmfoundry.models.layers.blocks import FusedNormAttentionNorm, MPTBlock from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY @@ -25,18 +24,19 @@ def get_act_ckpt_module(mod_name: str) -> Any: """Get the module type from the module name.""" if mod_name.lower() == 'mptblock': mod_type = MPTBlock + elif mod_name in attention_classes: + mod_type = attention_classes.get(mod_name) elif mod_name.lower() == 'norm_attn_norm': mod_type = FusedNormAttentionNorm - elif mod_name in ATTN_CLASS_REGISTRY: - mod_type = ATTN_CLASS_REGISTRY[mod_name] elif mod_name in FFN_CLASS_REGISTRY: mod_type = FFN_CLASS_REGISTRY[mod_name] elif mod_name in norms: mod_type = norms.get(mod_name) else: msg = ', '.join( - list(ATTN_CLASS_REGISTRY.keys()) + list(FFN_CLASS_REGISTRY.keys()) + - list(norms.get_all()) + ['MPTBlock']) + list(attention_classes.get_all()) + + list(FFN_CLASS_REGISTRY.keys()) + list(norms.get_all()) + + ['MPTBlock']) raise ValueError( f'{mod_name} (specified in activation_checkpointing_target) is not a recognized option out of available options {msg}.' ) diff --git a/llmfoundry/registry.py b/llmfoundry/registry.py index 7b605808be..ef6629be10 100644 --- a/llmfoundry/registry.py +++ b/llmfoundry/registry.py @@ -12,7 +12,8 @@ from transformers import PreTrainedTokenizerBase from llmfoundry.interfaces import CallbackWithConfig -from llmfoundry.layers_registry import fcs, norms +from llmfoundry.layers_registry import (attention_classes, + attention_implementations, fcs, norms) from llmfoundry.utils.registry_utils import create_registry _loggers_description = ( @@ -85,17 +86,24 @@ entry_points=True, description=_schedulers_description) -_models_description = """The models registry is used to register classes that implement the ComposerModel interface. The model -constructor should accept two arguments: an omegaconf DictConfig named `om_model_config` and a PreTrainedTokenizerBase named `tokenizer`. -Note: This will soon be updated to take in named kwargs instead of a config directly.""" +_models_description = ( + 'The models registry is used to register classes that implement the ComposerModel interface. ' + + + 'The model constructor should accept two arguments: an omegaconf DictConfig named `om_model_config` and a PreTrainedTokenizerBase named `tokenizer`. ' + + + 'Note: This will soon be updated to take in named kwargs instead of a config directly.' +) models = create_registry('llmfoundry', 'models', generic_type=Type[ComposerModel], entry_points=True, description=_models_description) -_dataloaders_description = """The dataloaders registry is used to register functions that create a DataSpec. The function should take -a DictConfig, a PreTrainedTokenizerBase, and an int as arguments, and return a DataSpec.""" +_dataloaders_description = ( + 'The dataloaders registry is used to register functions that create a DataSpec. The function should take ' + + + 'a DictConfig, a PreTrainedTokenizerBase, and an int as arguments, and return a DataSpec.' 
+) dataloaders = create_registry( 'llmfoundry', 'dataloaders', @@ -103,7 +111,9 @@ entry_points=True, description=_dataloaders_description) -_metrics_description = """The metrics registry is used to register classes that implement the torchmetrics.Metric interface.""" +_metrics_description = ( + 'The metrics registry is used to register classes that implement the torchmetrics.Metric interface.' +) metrics = create_registry('llmfoundry', 'metrics', generic_type=Type[Metric], @@ -121,5 +131,7 @@ 'metrics', 'dataloaders', 'norms', + 'attention_classes', + 'attention_implementations', 'fcs', ] diff --git a/llmfoundry/utils/registry_utils.py b/llmfoundry/utils/registry_utils.py index 0901ea198a..d9c23e6f26 100644 --- a/llmfoundry/utils/registry_utils.py +++ b/llmfoundry/utils/registry_utils.py @@ -15,6 +15,7 @@ T = TypeVar('T') TypeBoundT = TypeVar('TypeBoundT', bound=Type) +CallableBoundT = TypeVar('CallableBoundT', bound=Callable[..., Any]) class TypedRegistry(catalogue.Registry, Generic[T]): diff --git a/tests/models/layers/test_flash_torch.py b/tests/models/layers/test_flash_torch.py index c0e9f4b3b5..f212665c93 100644 --- a/tests/models/layers/test_flash_torch.py +++ b/tests/models/layers/test_flash_torch.py @@ -8,6 +8,7 @@ from llmfoundry.models.layers import attention from llmfoundry.models.layers.attention import (check_alibi_support, gen_slopes, is_flash_v2_installed) +from llmfoundry.models.layers.layer_builders import build_attention_layer from llmfoundry.models.mpt.modeling_mpt import (apply_sequence_id, gen_attention_mask_in_length, gen_flash_attn_padding_info, @@ -120,9 +121,15 @@ def test_attn_impl(attn_impl_0: str, ]).to(device=device) cfg.attn_impl = attn_impl_0 - attn0 = attention.ATTN_CLASS_REGISTRY[attn_type](**cfg).to(device) + attn0 = build_attention_layer( + name=attn_type, + attn_kwargs=om.to_container(cfg), # type: ignore + ).to(device) cfg.attn_impl = attn_impl_1 - attn1 = attention.ATTN_CLASS_REGISTRY[attn_type](**cfg).to(device) + attn1 = build_attention_layer( + name=attn_type, + attn_kwargs=om.to_container(cfg), # type: ignore + ).to(device) attn1.load_state_dict(attn0.state_dict()) diff --git a/tests/models/test_rope_dail_vs_hf.py b/tests/models/test_rope_dail_vs_hf.py index 33c3d3c052..b9ab90357a 100644 --- a/tests/models/test_rope_dail_vs_hf.py +++ b/tests/models/test_rope_dail_vs_hf.py @@ -7,6 +7,7 @@ from omegaconf import OmegaConf as om from llmfoundry.models.layers.attention import is_flash_v2_installed +from llmfoundry.models.layers.layer_builders import build_attention_layer from llmfoundry.models.mpt.modeling_mpt import (gen_flash_attn_padding_info, gen_rotary_embedding) @@ -21,8 +22,6 @@ def test_rope_dail_vs_hf(attn_type: str, seq_len: int, device: str = 'cuda'): if not is_flash_v2_installed(): pytest.skip('dail implementation of rope requires flash attention 2.') - from llmfoundry.models.layers import attention - cfg = om.create({ 'attn_impl': 'flash', 'd_model': 128, @@ -37,8 +36,16 @@ def test_rope_dail_vs_hf(attn_type: str, seq_len: int, device: str = 'cuda'): if attn_type == 'grouped_query_attention': cfg.kv_n_heads = 2 - attn0 = attention.ATTN_CLASS_REGISTRY[attn_type](**cfg).to(device) - attn1 = attention.ATTN_CLASS_REGISTRY[attn_type](**cfg).to(device) + attn0 = build_attention_layer( + name=attn_type, + attn_kwargs=om.to_container( + cfg), # type: ignore (to_container return broad type) + ).to(device) + attn1 = build_attention_layer( + name=attn_type, + attn_kwargs=om.to_container( + cfg), # type: ignore (to_container return broad type) + 
).to(device) attn1.load_state_dict(attn0.state_dict()) x0 = torch.randn(batch_size, seq_len, cfg.d_model).to(device) diff --git a/tests/test_registry.py b/tests/test_registry.py index 6fa0f16b49..29d8e137f3 100644 --- a/tests/test_registry.py +++ b/tests/test_registry.py @@ -31,6 +31,8 @@ def test_expected_registries_exist(): 'metrics', 'models', 'norms', + 'attention_classes', + 'attention_implementations', 'fcs', }
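
For reference, a minimal sketch of how code outside this PR could use the new attention_classes registry and the build_attention_layer helper introduced above. The registration and construction calls (register_class, build_attention_layer) are the same ones shown in the diff; MyAttention and its keyword arguments are hypothetical placeholders, and the real constructor contract is whatever attention.py documents.

# Hypothetical downstream code, not part of this PR: it only exercises the
# registration/construction calls added above. MyAttention and its kwargs are
# placeholder names; see attention.py for the expected constructor signature.
import torch

from llmfoundry.layers_registry import attention_classes
from llmfoundry.models.layers.layer_builders import build_attention_layer


@attention_classes.register_class('my_attention')
class MyAttention(torch.nn.Module):
    """Toy attention layer registered under the name 'my_attention'."""

    def __init__(self, d_model: int, n_heads: int, **kwargs: object):
        super().__init__()
        self.n_heads = n_heads
        self.Wqkv = torch.nn.Linear(d_model, 3 * d_model)
        self.out_proj = torch.nn.Linear(d_model, d_model)


# MPTBlock now resolves attn_config['attn_type'] through the same call:
attn = build_attention_layer(
    name='my_attention',
    attn_kwargs={'d_model': 128, 'n_heads': 4},
)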
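
And a similar sketch for the attention_implementations registry, which holds the attention op itself rather than the layer class. The register(..., func=...) and .get(...) calls match the ones used above for 'flash' and 'torch'; my_attn_fn is a placeholder, and its real expected signature is the one shared by flash_attn_fn and scaled_multihead_dot_product_attention in attention.py (only a few of those arguments are spelled out here).

# Hypothetical sketch, not part of this PR. my_attn_fn stands in for a real
# attention kernel; the true contract is the shared signature of flash_attn_fn
# and scaled_multihead_dot_product_attention, which return
# (output, attn_weights, past_key_value).
from typing import Any, Optional, Tuple

import torch

from llmfoundry.layers_registry import attention_implementations


def my_attn_fn(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    n_heads: int,
    kv_n_heads: int,
    **kwargs: Any,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Any]]:
    raise NotImplementedError  # stand-in only


attention_implementations.register('my_attn_impl', func=my_attn_fn)

# GroupedQueryAttention now resolves the op by name instead of an if/elif chain:
#     self.attn_fn = attention_implementations.get(self.attn_impl)
assert attention_implementations.get('my_attn_impl') is my_attn_fn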