From 1c7b82cac006f6636c21cca7f89aa26777b0748e Mon Sep 17 00:00:00 2001
From: Daniel King
Date: Tue, 2 Apr 2024 16:22:07 -0700
Subject: [PATCH] attention registry

---
 llmfoundry/layers_registry.py              | 20 ++++++++++++++++++--
 llmfoundry/models/layers/__init__.py       |  7 +++----
 llmfoundry/models/layers/attention.py      | 19 ++++++++-----------
 llmfoundry/models/layers/blocks.py         | 22 ++++++++++++----------
 llmfoundry/models/layers/layer_builders.py | 14 ++++++++++++--
 llmfoundry/models/utils/act_ckpt.py        |  9 ++++-----
 llmfoundry/utils/registry_utils.py         |  1 +
 tests/models/layers/test_flash_torch.py    | 11 +++++++++--
 tests/models/test_rope_dail_vs_hf.py       | 13 +++++++++----
 9 files changed, 76 insertions(+), 40 deletions(-)

diff --git a/llmfoundry/layers_registry.py b/llmfoundry/layers_registry.py
index 9c7dabe128..f683c245da 100644
--- a/llmfoundry/layers_registry.py
+++ b/llmfoundry/layers_registry.py
@@ -1,13 +1,12 @@
 # Copyright 2024 MosaicML LLM Foundry authors
 # SPDX-License-Identifier: Apache-2.0
 
-from typing import Type
+from typing import Callable, Type
 
 import torch
 
 from llmfoundry.utils.registry_utils import create_registry
 
-# Layers
 _norm_description = """The norms registry is used to register classes that implement normalization layers."""
 norms = create_registry('llmfoundry',
                         'norms',
@@ -15,6 +14,23 @@
                         entry_points=True,
                         description=_norm_description)
 
+_attention_class_description = """The attention_class registry is used to register classes that implement attention layers."""
+attention_class = create_registry('llmfoundry',
+                                  'attention_class',
+                                  generic_type=Type[torch.nn.Module],
+                                  entry_points=True,
+                                  description=_attention_class_description)
+
+_attention_implementation_description = """The attention_implementation registry is used to register functions that implement the attention operation."""
+attention_implementation = create_registry(
+    'llmfoundry',
+    'attention_implementation',
+    generic_type=Callable,
+    entry_points=True,
+    description=_attention_implementation_description)
+
 __all__ = [
     'norms',
+    'attention_class',
+    'attention_implementation',
 ]
diff --git a/llmfoundry/models/layers/__init__.py b/llmfoundry/models/layers/__init__.py
index 262f190b47..023b9edb8f 100644
--- a/llmfoundry/models/layers/__init__.py
+++ b/llmfoundry/models/layers/__init__.py
@@ -2,9 +2,9 @@
 # SPDX-License-Identifier: Apache-2.0
 
 from llmfoundry.models.layers.attention import (
-    ATTN_CLASS_REGISTRY, GroupedQueryAttention, MultiheadAttention,
-    MultiQueryAttention, attn_bias_shape, build_alibi_bias, build_attn_bias,
-    flash_attn_fn, scaled_multihead_dot_product_attention)
+    GroupedQueryAttention, MultiheadAttention, MultiQueryAttention,
+    attn_bias_shape, build_alibi_bias, build_attn_bias, flash_attn_fn,
+    scaled_multihead_dot_product_attention)
 from llmfoundry.models.layers.blocks import MPTBlock
 from llmfoundry.models.layers.custom_embedding import SharedEmbedding
 from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY
@@ -20,7 +20,6 @@
     'attn_bias_shape',
     'build_attn_bias',
     'build_alibi_bias',
-    'ATTN_CLASS_REGISTRY',
     'MPTMLP',
     'MPTBlock',
     'LPLayerNorm',
diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py
index c24b3d4afa..1b8f0c2e1d 100644
--- a/llmfoundry/models/layers/attention.py
+++ b/llmfoundry/models/layers/attention.py
@@ -14,6 +14,7 @@
 from packaging import version
 from torch import nn
 
+from llmfoundry.layers_registry import attention_class, attention_implementation
 from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY
 from llmfoundry.models.layers.layer_builders import build_norm
 
@@ -341,6 +342,7 @@ def flash_attn_fn(
     return output, None, past_key_value
 
 
+@attention_class.register_class('grouped_query_attention')
 class GroupedQueryAttention(nn.Module):
     """Grouped Query Attention (GQA) is a generalization of Multi-head (MHA).
 
@@ -433,12 +435,7 @@ def __init__(
                 device=device,
             )
 
-        if self.attn_impl == 'flash':
-            self.attn_fn = flash_attn_fn
-        elif self.attn_impl == 'torch':
-            self.attn_fn = scaled_multihead_dot_product_attention
-        else:
-            raise ValueError(f'{attn_impl=} is an invalid setting.')
+        self.attn_fn = attention_implementation.get(self.attn_impl)
 
         self.out_proj = FC_CLASS_REGISTRY[fc_type](
             self.d_model,
@@ -572,6 +569,7 @@ def forward(
         return self.out_proj(context), attn_weights, past_key_value
 
 
+@attention_class.register_class('multihead_attention')
 class MultiheadAttention(GroupedQueryAttention):
     """Multi-head self attention.
 
@@ -612,6 +610,7 @@ def __init__(
         )
 
 
+@attention_class.register_class('multiquery_attention')
 class MultiQueryAttention(GroupedQueryAttention):
     """Multi-Query self attention.
 
@@ -740,8 +739,6 @@ def build_alibi_bias(
     return alibi_bias.to(dtype=dtype)
 
 
-ATTN_CLASS_REGISTRY = {
-    'multihead_attention': MultiheadAttention,
-    'multiquery_attention': MultiQueryAttention,
-    'grouped_query_attention': GroupedQueryAttention
-}
+attention_implementation.register('flash', func=flash_attn_fn)
+attention_implementation.register('torch',
+                                  func=scaled_multihead_dot_product_attention)
diff --git a/llmfoundry/models/layers/blocks.py b/llmfoundry/models/layers/blocks.py
index 42feb983d4..c6a7f7dbf8 100644
--- a/llmfoundry/models/layers/blocks.py
+++ b/llmfoundry/models/layers/blocks.py
@@ -8,9 +8,9 @@
 import torch
 import torch.nn as nn
 
-from llmfoundry.models.layers.attention import ATTN_CLASS_REGISTRY
 from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY, build_ffn
-from llmfoundry.models.layers.layer_builders import build_norm
+from llmfoundry.models.layers.layer_builders import (build_attention_layer,
+                                                     build_norm)
 
 try:
     from flash_attn.bert_padding import unpad_input, pad_input  # type: ignore # yapf: disable # isort: skip
@@ -73,7 +73,6 @@ def __init__(
         super().__init__()
 
         assert isinstance(attn_config['attn_type'], str)
-        attn_class = ATTN_CLASS_REGISTRY[attn_config['attn_type']]
 
         # necessary to avoid passing extraneous args into attn_class while allowing the use of **kwargs
         args_to_exclude_in_attn_class = {
@@ -92,13 +91,16 @@ def __init__(
             normalized_shape=d_model,
             device=device,
         )
-        self.attn = attn_class(
-            d_model=d_model,
-            n_heads=n_heads,
-            fc_type=fc_type,
-            device=device,
-            **attn_config_subset_for_attn_class,
-            bias=not no_bias,
+        self.attn = build_attention_layer(
+            name=attn_config['attn_type'],
+            attn_kwargs={
+                'd_model': d_model,
+                'n_heads': n_heads,
+                'fc_type': fc_type,
+                'device': device,
+                'bias': not no_bias,
+                **attn_config_subset_for_attn_class
+            },
         )
         self.norm_2 = None
         if not getattr(FFN_CLASS_REGISTRY[ffn_config['ffn_type']], '_has_norm',
diff --git a/llmfoundry/models/layers/layer_builders.py b/llmfoundry/models/layers/layer_builders.py
index 23f5b89668..e5589f8d02 100644
--- a/llmfoundry/models/layers/layer_builders.py
+++ b/llmfoundry/models/layers/layer_builders.py
@@ -1,11 +1,11 @@
 # Copyright 2024 MosaicML LLM Foundry authors
 # SPDX-License-Identifier: Apache-2.0
 
-from typing import List, Optional, Union
+from typing import Any, Dict, List, Optional, Union
 
 import torch
 
-from llmfoundry.layers_registry import norms
+from llmfoundry.layers_registry import attention_class, norms
 from llmfoundry.utils.registry_utils import construct_from_registry
 
 
@@ -23,3 +23,13 @@ def build_norm(
         registry=norms,
         pre_validation_function=torch.nn.Module,
         kwargs=kwargs)
+
+
+def build_attention_layer(
+    name: str,
+    attn_kwargs: Dict[str, Any],
+):
+    return construct_from_registry(name=name,
+                                   registry=attention_class,
+                                   pre_validation_function=torch.nn.Module,
+                                   kwargs=attn_kwargs)
diff --git a/llmfoundry/models/utils/act_ckpt.py b/llmfoundry/models/utils/act_ckpt.py
index bde7c92bd7..3399e70e03 100644
--- a/llmfoundry/models/utils/act_ckpt.py
+++ b/llmfoundry/models/utils/act_ckpt.py
@@ -5,8 +5,7 @@
 
 import torch
 
-from llmfoundry.layers_registry import norms
-from llmfoundry.models.layers.attention import ATTN_CLASS_REGISTRY
+from llmfoundry.layers_registry import attention_class, norms
 from llmfoundry.models.layers.blocks import MPTBlock
 from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY
 
@@ -25,15 +24,15 @@ def get_act_ckpt_module(mod_name: str) -> Any:
     """Get the module type from the module name."""
     if mod_name.lower() == 'mptblock':
         mod_type = MPTBlock
-    elif mod_name in ATTN_CLASS_REGISTRY:
-        mod_type = ATTN_CLASS_REGISTRY[mod_name]
+    elif mod_name in attention_class.get_all():
+        mod_type = attention_class.get(mod_name)
     elif mod_name in FFN_CLASS_REGISTRY:
         mod_type = FFN_CLASS_REGISTRY[mod_name]
     elif mod_name in norms:
         mod_type = norms.get(mod_name)
     else:
         msg = ', '.join(
-            list(ATTN_CLASS_REGISTRY.keys()) + list(FFN_CLASS_REGISTRY.keys()) +
+            list(attention_class.get_all()) + list(FFN_CLASS_REGISTRY.keys()) +
             list(norms.get_all()) + ['MPTBlock'])
         raise ValueError(
             f'{mod_name} (specified in activation_checkpointing_target) is not a recognized option out of available options {msg}.'
diff --git a/llmfoundry/utils/registry_utils.py b/llmfoundry/utils/registry_utils.py
index 0901ea198a..d9c23e6f26 100644
--- a/llmfoundry/utils/registry_utils.py
+++ b/llmfoundry/utils/registry_utils.py
@@ -15,6 +15,7 @@
 
 T = TypeVar('T')
 TypeBoundT = TypeVar('TypeBoundT', bound=Type)
+CallableBoundT = TypeVar('CallableBoundT', bound=Callable[..., Any])
 
 
 class TypedRegistry(catalogue.Registry, Generic[T]):
diff --git a/tests/models/layers/test_flash_torch.py b/tests/models/layers/test_flash_torch.py
index c0e9f4b3b5..f212665c93 100644
--- a/tests/models/layers/test_flash_torch.py
+++ b/tests/models/layers/test_flash_torch.py
@@ -8,6 +8,7 @@
 from llmfoundry.models.layers import attention
 from llmfoundry.models.layers.attention import (check_alibi_support, gen_slopes,
                                                 is_flash_v2_installed)
+from llmfoundry.models.layers.layer_builders import build_attention_layer
 from llmfoundry.models.mpt.modeling_mpt import (apply_sequence_id,
                                                 gen_attention_mask_in_length,
                                                 gen_flash_attn_padding_info,
@@ -120,9 +121,15 @@ def test_attn_impl(attn_impl_0: str,
                                        ]).to(device=device)
 
     cfg.attn_impl = attn_impl_0
-    attn0 = attention.ATTN_CLASS_REGISTRY[attn_type](**cfg).to(device)
+    attn0 = build_attention_layer(
+        name=attn_type,
+        attn_kwargs=om.to_container(cfg),  # type: ignore
+    ).to(device)
     cfg.attn_impl = attn_impl_1
-    attn1 = attention.ATTN_CLASS_REGISTRY[attn_type](**cfg).to(device)
+    attn1 = build_attention_layer(
+        name=attn_type,
+        attn_kwargs=om.to_container(cfg),  # type: ignore
+    ).to(device)
 
     attn1.load_state_dict(attn0.state_dict())
 
diff --git a/tests/models/test_rope_dail_vs_hf.py b/tests/models/test_rope_dail_vs_hf.py
index 33c3d3c052..c591aecc73 100644
--- a/tests/models/test_rope_dail_vs_hf.py
+++ b/tests/models/test_rope_dail_vs_hf.py
@@ -7,6 +7,7 @@
 from omegaconf import OmegaConf as om
 
 from llmfoundry.models.layers.attention import is_flash_v2_installed
+from llmfoundry.models.layers.layer_builders import build_attention_layer
 from llmfoundry.models.mpt.modeling_mpt import (gen_flash_attn_padding_info,
                                                 gen_rotary_embedding)
 
@@ -21,8 +22,6 @@ def test_rope_dail_vs_hf(attn_type: str, seq_len: int, device: str = 'cuda'):
     if not is_flash_v2_installed():
         pytest.skip('dail implementation of rope requires flash attention 2.')
 
-    from llmfoundry.models.layers import attention
-
     cfg = om.create({
         'attn_impl': 'flash',
         'd_model': 128,
@@ -37,8 +36,14 @@ def test_rope_dail_vs_hf(attn_type: str, seq_len: int, device: str = 'cuda'):
     if attn_type == 'grouped_query_attention':
         cfg.kv_n_heads = 2
 
-    attn0 = attention.ATTN_CLASS_REGISTRY[attn_type](**cfg).to(device)
-    attn1 = attention.ATTN_CLASS_REGISTRY[attn_type](**cfg).to(device)
+    attn0 = build_attention_layer(
+        name=attn_type,
+        attn_kwargs=om.to_container(cfg),  # type: ignore
+    ).to(device)
+    attn1 = build_attention_layer(
+        name=attn_type,
+        attn_kwargs=om.to_container(cfg),  # type: ignore
+    ).to(device)
 
     attn1.load_state_dict(attn0.state_dict())
     x0 = torch.randn(batch_size, seq_len, cfg.d_model).to(device)
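
Usage sketch (not part of the patch): with this change, an external package can register its own attention layer and have it constructed by name, the same path MPTBlock now uses. The snippet below relies only on APIs introduced in this diff (attention_class.register_class and build_attention_layer); the MyAttention class and its concrete kwargs are hypothetical placeholders, not code from the repository.

    import torch

    from llmfoundry.layers_registry import attention_class
    from llmfoundry.models.layers.layer_builders import build_attention_layer

    # Hypothetical custom attention layer; registering it under a new name lets
    # build_attention_layer construct it from that name.
    @attention_class.register_class('my_custom_attention')
    class MyAttention(torch.nn.Module):

        def __init__(self, d_model: int, n_heads: int, **kwargs):
            super().__init__()
            self.n_heads = n_heads
            self.out_proj = torch.nn.Linear(d_model, d_model)

        def forward(self, x, **kwargs):
            # Mirror the (output, attn_weights, past_key_value) return shape of
            # the built-in attention classes.
            return self.out_proj(x), None, None

    # Build the layer by name through the registry, as MPTBlock now does.
    attn = build_attention_layer(
        name='my_custom_attention',
        attn_kwargs={'d_model': 128, 'n_heads': 4},
    )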