attention registry
dakinggg committed Apr 4, 2024
1 parent 827b9a1 commit 1c7b82c
Showing 9 changed files with 76 additions and 40 deletions.
20 changes: 18 additions & 2 deletions llmfoundry/layers_registry.py
@@ -1,20 +1,36 @@
# Copyright 2024 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

from typing import Type
from typing import Callable, Type

import torch

from llmfoundry.utils.registry_utils import create_registry

# Layers
_norm_description = """The norms registry is used to register classes that implement normalization layers."""
norms = create_registry('llmfoundry',
'norms',
generic_type=Type[torch.nn.Module],
entry_points=True,
description=_norm_description)

_attention_class_description = """The attention_class registry is used to register classes that implement attention layers."""
attention_class = create_registry('llmfoundry',
'attention_class',
generic_type=Type[torch.nn.Module],
entry_points=True,
description=_attention_class_description)

_attention_implementation_description = """The attention_implementation registry is used to register functions that implement the attention operation."""
attention_implementation = create_registry(
'llmfoundry',
'attention_implementation',
generic_type=Callable,
entry_points=True,
description=_attention_implementation_description)

__all__ = [
'norms',
'attention_class',
'attention_implementation',
]
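
As a rough illustration of how the new registries are meant to be used (not part of this commit; the 'my_custom_attention' name and class below are hypothetical), downstream code could register its own attention class the same way the built-in classes are registered in attention.py:

# A minimal sketch, assuming llmfoundry is installed; hypothetical names only.
import torch

from llmfoundry.layers_registry import attention_class


@attention_class.register_class('my_custom_attention')  # hypothetical name
class MyCustomAttention(torch.nn.Module):

    def __init__(self, d_model: int, n_heads: int, **kwargs):
        super().__init__()
        self.out_proj = torch.nn.Linear(d_model, d_model)

    def forward(self, x: torch.Tensor, **kwargs):
        # Placeholder body; a real implementation would compute attention here
        # and return (context, attn_weights, past_key_value).
        return self.out_proj(x), None, None
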
7 changes: 3 additions & 4 deletions llmfoundry/models/layers/__init__.py
@@ -2,9 +2,9 @@
# SPDX-License-Identifier: Apache-2.0

from llmfoundry.models.layers.attention import (
ATTN_CLASS_REGISTRY, GroupedQueryAttention, MultiheadAttention,
MultiQueryAttention, attn_bias_shape, build_alibi_bias, build_attn_bias,
flash_attn_fn, scaled_multihead_dot_product_attention)
GroupedQueryAttention, MultiheadAttention, MultiQueryAttention,
attn_bias_shape, build_alibi_bias, build_attn_bias, flash_attn_fn,
scaled_multihead_dot_product_attention)
from llmfoundry.models.layers.blocks import MPTBlock
from llmfoundry.models.layers.custom_embedding import SharedEmbedding
from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY
@@ -20,7 +20,6 @@
'attn_bias_shape',
'build_attn_bias',
'build_alibi_bias',
'ATTN_CLASS_REGISTRY',
'MPTMLP',
'MPTBlock',
'LPLayerNorm',
19 changes: 8 additions & 11 deletions llmfoundry/models/layers/attention.py
@@ -14,6 +14,7 @@
from packaging import version
from torch import nn

from llmfoundry.layers_registry import attention_class, attention_implementation
from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY
from llmfoundry.models.layers.layer_builders import build_norm

@@ -341,6 +342,7 @@ def flash_attn_fn(
return output, None, past_key_value


@attention_class.register_class('grouped_query_attention')
class GroupedQueryAttention(nn.Module):
"""Grouped Query Attention (GQA) is a generalization of Multi-head (MHA).
@@ -433,12 +435,7 @@ def __init__(
device=device,
)

if self.attn_impl == 'flash':
self.attn_fn = flash_attn_fn
elif self.attn_impl == 'torch':
self.attn_fn = scaled_multihead_dot_product_attention
else:
raise ValueError(f'{attn_impl=} is an invalid setting.')
self.attn_fn = attention_implementation.get(self.attn_impl)

self.out_proj = FC_CLASS_REGISTRY[fc_type](
self.d_model,
@@ -572,6 +569,7 @@ def forward(
return self.out_proj(context), attn_weights, past_key_value


@attention_class.register_class('multihead_attention')
class MultiheadAttention(GroupedQueryAttention):
"""Multi-head self attention.
@@ -612,6 +610,7 @@ def __init__(
)


@attention_class.register_class('multiquery_attention')
class MultiQueryAttention(GroupedQueryAttention):
"""Multi-Query self attention.
@@ -740,8 +739,6 @@ def build_alibi_bias(
return alibi_bias.to(dtype=dtype)


ATTN_CLASS_REGISTRY = {
'multihead_attention': MultiheadAttention,
'multiquery_attention': MultiQueryAttention,
'grouped_query_attention': GroupedQueryAttention
}
attention_implementation.register('flash', func=flash_attn_fn)
attention_implementation.register('torch',
func=scaled_multihead_dot_product_attention)
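
For illustration only (the 'my_impl' name and function below are hypothetical, not part of this commit), an alternative attention operation with the same signature could be registered and then selected via attn_impl, since the constructor above now resolves it with attention_implementation.get(self.attn_impl):

# A minimal sketch, assuming llmfoundry is installed; hypothetical names only.
import llmfoundry.models.layers.attention  # noqa: F401  # registers 'flash' and 'torch'
from llmfoundry.layers_registry import attention_implementation


def my_attn_fn(*args, **kwargs):
    # Placeholder: delegate to the registered 'torch' implementation; a real
    # implementation would return (output, attn_weights, past_key_value).
    return attention_implementation.get('torch')(*args, **kwargs)


attention_implementation.register('my_impl', func=my_attn_fn)
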
22 changes: 12 additions & 10 deletions llmfoundry/models/layers/blocks.py
@@ -8,9 +8,9 @@
import torch
import torch.nn as nn

from llmfoundry.models.layers.attention import ATTN_CLASS_REGISTRY
from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY, build_ffn
from llmfoundry.models.layers.layer_builders import build_norm
from llmfoundry.models.layers.layer_builders import (build_attention_layer,
build_norm)

try:
from flash_attn.bert_padding import unpad_input, pad_input # type: ignore # yapf: disable # isort: skip
@@ -73,7 +73,6 @@ def __init__(
super().__init__()

assert isinstance(attn_config['attn_type'], str)
attn_class = ATTN_CLASS_REGISTRY[attn_config['attn_type']]

# necessary to avoid passing extraneous args into attn_class while allowing the use of **kwargs
args_to_exclude_in_attn_class = {
@@ -92,13 +91,16 @@
normalized_shape=d_model,
device=device,
)
self.attn = attn_class(
d_model=d_model,
n_heads=n_heads,
fc_type=fc_type,
device=device,
**attn_config_subset_for_attn_class,
bias=not no_bias,
self.attn = build_attention_layer(
name=attn_config['attn_type'],
attn_kwargs={
'd_model': d_model,
'n_heads': n_heads,
'fc_type': fc_type,
'device': device,
'bias': not no_bias,
**attn_config_subset_for_attn_class
},
)
self.norm_2 = None
if not getattr(FFN_CLASS_REGISTRY[ffn_config['ffn_type']], '_has_norm',
14 changes: 12 additions & 2 deletions llmfoundry/models/layers/layer_builders.py
@@ -1,11 +1,11 @@
# Copyright 2024 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

from typing import List, Optional, Union
from typing import Any, Dict, List, Optional, Union

import torch

from llmfoundry.layers_registry import norms
from llmfoundry.layers_registry import attention_class, norms
from llmfoundry.utils.registry_utils import construct_from_registry


@@ -23,3 +23,13 @@ def build_norm(
registry=norms,
pre_validation_function=torch.nn.Module,
kwargs=kwargs)


def build_attention_layer(
name: str,
attn_kwargs: Dict[str, Any],
):
return construct_from_registry(name=name,
registry=attention_class,
pre_validation_function=torch.nn.Module,
kwargs=attn_kwargs)
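
A usage sketch for the new builder (mirroring the updated tests further down; the kwarg values here are illustrative, and the remaining constructor arguments are assumed to have defaults):

# A minimal sketch, assuming llmfoundry is installed.
import llmfoundry.models.layers.attention  # noqa: F401  # populates the registries
from llmfoundry.models.layers.layer_builders import build_attention_layer

attn = build_attention_layer(
    name='grouped_query_attention',
    attn_kwargs={
        'd_model': 128,
        'n_heads': 4,
        'kv_n_heads': 2,
        'attn_impl': 'torch',
    },
)
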
9 changes: 4 additions & 5 deletions llmfoundry/models/utils/act_ckpt.py
@@ -5,8 +5,7 @@

import torch

from llmfoundry.layers_registry import norms
from llmfoundry.models.layers.attention import ATTN_CLASS_REGISTRY
from llmfoundry.layers_registry import attention_class, norms
from llmfoundry.models.layers.blocks import MPTBlock
from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY

@@ -25,15 +24,15 @@ def get_act_ckpt_module(mod_name: str) -> Any:
"""Get the module type from the module name."""
if mod_name.lower() == 'mptblock':
mod_type = MPTBlock
elif mod_name in ATTN_CLASS_REGISTRY:
mod_type = ATTN_CLASS_REGISTRY[mod_name]
elif mod_name in attention_class.get_all():
mod_type = attention_class.get(mod_name)
elif mod_name in FFN_CLASS_REGISTRY:
mod_type = FFN_CLASS_REGISTRY[mod_name]
elif mod_name in norms:
mod_type = norms.get(mod_name)
else:
msg = ', '.join(
list(ATTN_CLASS_REGISTRY.keys()) + list(FFN_CLASS_REGISTRY.keys()) +
list(attention_class.get_all()) + list(FFN_CLASS_REGISTRY.keys()) +
list(norms.get_all()) + ['MPTBlock'])
raise ValueError(
f'{mod_name} (specified in activation_checkpointing_target) is not a recognized option out of available options {msg}.'
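
The same lookup can be reproduced outside of get_act_ckpt_module; a minimal sketch, assuming the registrations in attention.py have already been imported:

import llmfoundry.models.layers.attention  # noqa: F401  # populates attention_class
from llmfoundry.layers_registry import attention_class

mod_name = 'grouped_query_attention'  # e.g. an activation_checkpointing_target entry
if mod_name in attention_class.get_all():
    mod_type = attention_class.get(mod_name)  # resolves to GroupedQueryAttention
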
1 change: 1 addition & 0 deletions llmfoundry/utils/registry_utils.py
@@ -15,6 +15,7 @@

T = TypeVar('T')
TypeBoundT = TypeVar('TypeBoundT', bound=Type)
CallableBoundT = TypeVar('CallableBoundT', bound=Callable[..., Any])


class TypedRegistry(catalogue.Registry, Generic[T]):
11 changes: 9 additions & 2 deletions tests/models/layers/test_flash_torch.py
@@ -8,6 +8,7 @@
from llmfoundry.models.layers import attention
from llmfoundry.models.layers.attention import (check_alibi_support, gen_slopes,
is_flash_v2_installed)
from llmfoundry.models.layers.layer_builders import build_attention_layer
from llmfoundry.models.mpt.modeling_mpt import (apply_sequence_id,
gen_attention_mask_in_length,
gen_flash_attn_padding_info,
@@ -120,9 +121,15 @@ def test_attn_impl(attn_impl_0: str,
]).to(device=device)

cfg.attn_impl = attn_impl_0
attn0 = attention.ATTN_CLASS_REGISTRY[attn_type](**cfg).to(device)
attn0 = build_attention_layer(
name=attn_type,
attn_kwargs=om.to_container(cfg), # type: ignore
).to(device)
cfg.attn_impl = attn_impl_1
attn1 = attention.ATTN_CLASS_REGISTRY[attn_type](**cfg).to(device)
attn1 = build_attention_layer(
name=attn_type,
attn_kwargs=om.to_container(cfg), # type: ignore
).to(device)

attn1.load_state_dict(attn0.state_dict())

13 changes: 9 additions & 4 deletions tests/models/test_rope_dail_vs_hf.py
@@ -7,6 +7,7 @@
from omegaconf import OmegaConf as om

from llmfoundry.models.layers.attention import is_flash_v2_installed
from llmfoundry.models.layers.layer_builders import build_attention_layer
from llmfoundry.models.mpt.modeling_mpt import (gen_flash_attn_padding_info,
gen_rotary_embedding)

@@ -21,8 +22,6 @@ def test_rope_dail_vs_hf(attn_type: str, seq_len: int, device: str = 'cuda'):
if not is_flash_v2_installed():
pytest.skip('dail implementation of rope requires flash attention 2.')

from llmfoundry.models.layers import attention

cfg = om.create({
'attn_impl': 'flash',
'd_model': 128,
@@ -37,8 +36,14 @@
if attn_type == 'grouped_query_attention':
cfg.kv_n_heads = 2

attn0 = attention.ATTN_CLASS_REGISTRY[attn_type](**cfg).to(device)
attn1 = attention.ATTN_CLASS_REGISTRY[attn_type](**cfg).to(device)
attn0 = build_attention_layer(
name=attn_type,
attn_kwargs=om.to_container(cfg), # type: ignore
).to(device)
attn1 = build_attention_layer(
name=attn_type,
attn_kwargs=om.to_container(cfg), # type: ignore
).to(device)

attn1.load_state_dict(attn0.state_dict())
x0 = torch.randn(batch_size, seq_len, cfg.d_model).to(device)
