Skip to content

Commit

Permalink
Attention layer registry (#1094)
Browse files Browse the repository at this point in the history
  • Loading branch information
dakinggg authored Apr 12, 2024
1 parent ed3daef commit 560012b
Show file tree
Hide file tree
Showing 11 changed files with 118 additions and 57 deletions.
28 changes: 25 additions & 3 deletions llmfoundry/layers_registry.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
# Copyright 2024 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

from typing import Type
from typing import Callable, Type

import torch

from llmfoundry.utils.registry_utils import create_registry

# Layers
_norm_description = """The norms registry is used to register classes that implement normalization layers."""
_norm_description = (
'The norms registry is used to register classes that implement normalization layers.'
)
norms = create_registry('llmfoundry',
'norms',
generic_type=Type[torch.nn.Module],
Expand All @@ -25,7 +26,28 @@
entry_points=True,
description=_fc_description)

_attention_classes_description = (
'The attention_classes registry is used to register classes that implement attention layers. See '
+ 'attention.py for expected constructor signature.')
attention_classes = create_registry('llmfoundry',
'attention_classes',
generic_type=Type[torch.nn.Module],
entry_points=True,
description=_attention_classes_description)

_attention_implementations_description = (
'The attention_implementations registry is used to register functions that implement the attention operation.'
+ 'See attention.py for expected function signature.')
attention_implementations = create_registry(
'llmfoundry',
'attention_implementations',
generic_type=Callable,
entry_points=True,
description=_attention_implementations_description)

__all__ = [
'norms',
'attention_classes',
'attention_implementations',
'fcs',
]
7 changes: 3 additions & 4 deletions llmfoundry/models/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
# SPDX-License-Identifier: Apache-2.0

from llmfoundry.models.layers.attention import (
ATTN_CLASS_REGISTRY, GroupedQueryAttention, MultiheadAttention,
MultiQueryAttention, attn_bias_shape, build_alibi_bias, build_attn_bias,
flash_attn_fn, scaled_multihead_dot_product_attention)
GroupedQueryAttention, MultiheadAttention, MultiQueryAttention,
attn_bias_shape, build_alibi_bias, build_attn_bias, flash_attn_fn,
scaled_multihead_dot_product_attention)
from llmfoundry.models.layers.blocks import MPTBlock
from llmfoundry.models.layers.custom_embedding import SharedEmbedding
from llmfoundry.models.layers.fc import *
Expand All @@ -20,7 +20,6 @@
'attn_bias_shape',
'build_attn_bias',
'build_alibi_bias',
'ATTN_CLASS_REGISTRY',
'MPTMLP',
'MPTBlock',
'LPLayerNorm',
Expand Down
20 changes: 9 additions & 11 deletions llmfoundry/models/layers/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
from packaging import version
from torch import nn

from llmfoundry.layers_registry import (attention_classes,
attention_implementations)
from llmfoundry.models.layers.layer_builders import build_fc, build_norm


Expand Down Expand Up @@ -340,6 +342,7 @@ def flash_attn_fn(
return output, None, past_key_value


@attention_classes.register_class('grouped_query_attention')
class GroupedQueryAttention(nn.Module):
"""Grouped Query Attention (GQA) is a generalization of Multi-head (MHA).
Expand Down Expand Up @@ -433,12 +436,7 @@ def __init__(
device=device,
)

if self.attn_impl == 'flash':
self.attn_fn = flash_attn_fn
elif self.attn_impl == 'torch':
self.attn_fn = scaled_multihead_dot_product_attention
else:
raise ValueError(f'{attn_impl=} is an invalid setting.')
self.attn_fn = attention_implementations.get(self.attn_impl)

self.out_proj = build_fc(
name=fc_type,
Expand Down Expand Up @@ -573,6 +571,7 @@ def forward(
return self.out_proj(context), attn_weights, past_key_value


@attention_classes.register_class('multihead_attention')
class MultiheadAttention(GroupedQueryAttention):
"""Multi-head self attention.
Expand Down Expand Up @@ -613,6 +612,7 @@ def __init__(
)


@attention_classes.register_class('multiquery_attention')
class MultiQueryAttention(GroupedQueryAttention):
"""Multi-Query self attention.
Expand Down Expand Up @@ -741,8 +741,6 @@ def build_alibi_bias(
return alibi_bias.to(dtype=dtype)


ATTN_CLASS_REGISTRY = {
'multihead_attention': MultiheadAttention,
'multiquery_attention': MultiQueryAttention,
'grouped_query_attention': GroupedQueryAttention
}
attention_implementations.register('flash', func=flash_attn_fn)
attention_implementations.register('torch',
func=scaled_multihead_dot_product_attention)
41 changes: 22 additions & 19 deletions llmfoundry/models/layers/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@
import torch
import torch.nn as nn

from llmfoundry.models.layers.attention import ATTN_CLASS_REGISTRY
from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY, build_ffn
from llmfoundry.models.layers.layer_builders import build_norm
from llmfoundry.models.layers.layer_builders import (build_attention_layer,
build_norm)

try:
from flash_attn.bert_padding import unpad_input, pad_input # type: ignore # yapf: disable # isort: skip
Expand Down Expand Up @@ -87,8 +87,6 @@ def __init__(
)
else:
assert isinstance(attn_config['attn_type'], str)
attn_class = ATTN_CLASS_REGISTRY[attn_config['attn_type']]

# Necessary to avoid passing extraneous args into attn_class while allowing the use of **kwargs
args_to_exclude_in_attn_class = {
'attn_type', 'alibi', 'attn_uses_sequence_id', 'alibi_bias_max',
Expand All @@ -106,13 +104,16 @@ def __init__(
normalized_shape=d_model,
device=device,
)
self.attn = attn_class(
d_model=d_model,
n_heads=n_heads,
fc_type=fc_type,
device=device,
**attn_config_subset_for_attn_class,
bias=not no_bias,
self.attn = build_attention_layer(
name=attn_config['attn_type'],
attn_kwargs={
'd_model': d_model,
'n_heads': n_heads,
'fc_type': fc_type,
'device': device,
'bias': not no_bias,
**attn_config_subset_for_attn_class
},
)
self.norm_2 = None
if not getattr(FFN_CLASS_REGISTRY[ffn_config['ffn_type']],
Expand Down Expand Up @@ -209,7 +210,6 @@ def __init__(
assert attn_config is not None
assert ffn_config is not None
assert isinstance(attn_config['attn_type'], str)
attn_class = ATTN_CLASS_REGISTRY[attn_config['attn_type']]

# necessary to avoid passing extraneous args into attn_class while allowing the use of **kwargs
args_to_exclude_in_attn_class = {
Expand All @@ -227,13 +227,16 @@ def __init__(
normalized_shape=d_model,
device=device,
)
self.attn = attn_class(
d_model=d_model,
n_heads=n_heads,
fc_type=fc_type,
device=device,
**attn_config_subset_for_attn_class,
bias=not no_bias,
self.attn = build_attention_layer(
name=attn_config['attn_type'],
attn_kwargs={
'd_model': d_model,
'n_heads': n_heads,
'fc_type': fc_type,
'device': device,
'bias': not no_bias,
**attn_config_subset_for_attn_class
},
)
self.norm_2 = None
if not getattr(FFN_CLASS_REGISTRY[ffn_config['ffn_type']], '_has_norm',
Expand Down
12 changes: 11 additions & 1 deletion llmfoundry/models/layers/layer_builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import torch

from llmfoundry.layers_registry import fcs, norms
from llmfoundry.layers_registry import attention_classes, fcs, norms
from llmfoundry.utils.registry_utils import construct_from_registry


Expand All @@ -25,6 +25,16 @@ def build_norm(
kwargs=kwargs)


def build_attention_layer(
name: str,
attn_kwargs: Dict[str, Any],
):
return construct_from_registry(name=name,
registry=attention_classes,
pre_validation_function=torch.nn.Module,
kwargs=attn_kwargs)


def build_fc(
name: str,
in_features: int,
Expand Down
12 changes: 6 additions & 6 deletions llmfoundry/models/utils/act_ckpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@

import torch

from llmfoundry.layers_registry import norms
from llmfoundry.models.layers.attention import ATTN_CLASS_REGISTRY
from llmfoundry.layers_registry import attention_classes, norms
from llmfoundry.models.layers.blocks import FusedNormAttentionNorm, MPTBlock
from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY

Expand All @@ -25,18 +24,19 @@ def get_act_ckpt_module(mod_name: str) -> Any:
"""Get the module type from the module name."""
if mod_name.lower() == 'mptblock':
mod_type = MPTBlock
elif mod_name in attention_classes:
mod_type = attention_classes.get(mod_name)
elif mod_name.lower() == 'norm_attn_norm':
mod_type = FusedNormAttentionNorm
elif mod_name in ATTN_CLASS_REGISTRY:
mod_type = ATTN_CLASS_REGISTRY[mod_name]
elif mod_name in FFN_CLASS_REGISTRY:
mod_type = FFN_CLASS_REGISTRY[mod_name]
elif mod_name in norms:
mod_type = norms.get(mod_name)
else:
msg = ', '.join(
list(ATTN_CLASS_REGISTRY.keys()) + list(FFN_CLASS_REGISTRY.keys()) +
list(norms.get_all()) + ['MPTBlock'])
list(attention_classes.get_all()) +
list(FFN_CLASS_REGISTRY.keys()) + list(norms.get_all()) +
['MPTBlock'])
raise ValueError(
f'{mod_name} (specified in activation_checkpointing_target) is not a recognized option out of available options {msg}.'
)
Expand Down
26 changes: 19 additions & 7 deletions llmfoundry/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
from transformers import PreTrainedTokenizerBase

from llmfoundry.interfaces import CallbackWithConfig
from llmfoundry.layers_registry import fcs, norms
from llmfoundry.layers_registry import (attention_classes,
attention_implementations, fcs, norms)
from llmfoundry.utils.registry_utils import create_registry

_loggers_description = (
Expand Down Expand Up @@ -85,25 +86,34 @@
entry_points=True,
description=_schedulers_description)

_models_description = """The models registry is used to register classes that implement the ComposerModel interface. The model
constructor should accept two arguments: an omegaconf DictConfig named `om_model_config` and a PreTrainedTokenizerBase named `tokenizer`.
Note: This will soon be updated to take in named kwargs instead of a config directly."""
_models_description = (
'The models registry is used to register classes that implement the ComposerModel interface. '
+
'The model constructor should accept two arguments: an omegaconf DictConfig named `om_model_config` and a PreTrainedTokenizerBase named `tokenizer`. '
+
'Note: This will soon be updated to take in named kwargs instead of a config directly.'
)
models = create_registry('llmfoundry',
'models',
generic_type=Type[ComposerModel],
entry_points=True,
description=_models_description)

_dataloaders_description = """The dataloaders registry is used to register functions that create a DataSpec. The function should take
a DictConfig, a PreTrainedTokenizerBase, and an int as arguments, and return a DataSpec."""
_dataloaders_description = (
'The dataloaders registry is used to register functions that create a DataSpec. The function should take '
+
'a DictConfig, a PreTrainedTokenizerBase, and an int as arguments, and return a DataSpec.'
)
dataloaders = create_registry(
'llmfoundry',
'dataloaders',
generic_type=Callable[[DictConfig, PreTrainedTokenizerBase, int], DataSpec],
entry_points=True,
description=_dataloaders_description)

_metrics_description = """The metrics registry is used to register classes that implement the torchmetrics.Metric interface."""
_metrics_description = (
'The metrics registry is used to register classes that implement the torchmetrics.Metric interface.'
)
metrics = create_registry('llmfoundry',
'metrics',
generic_type=Type[Metric],
Expand All @@ -121,5 +131,7 @@
'metrics',
'dataloaders',
'norms',
'attention_classes',
'attention_implementations',
'fcs',
]
1 change: 1 addition & 0 deletions llmfoundry/utils/registry_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

T = TypeVar('T')
TypeBoundT = TypeVar('TypeBoundT', bound=Type)
CallableBoundT = TypeVar('CallableBoundT', bound=Callable[..., Any])


class TypedRegistry(catalogue.Registry, Generic[T]):
Expand Down
11 changes: 9 additions & 2 deletions tests/models/layers/test_flash_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from llmfoundry.models.layers import attention
from llmfoundry.models.layers.attention import (check_alibi_support, gen_slopes,
is_flash_v2_installed)
from llmfoundry.models.layers.layer_builders import build_attention_layer
from llmfoundry.models.mpt.modeling_mpt import (apply_sequence_id,
gen_attention_mask_in_length,
gen_flash_attn_padding_info,
Expand Down Expand Up @@ -120,9 +121,15 @@ def test_attn_impl(attn_impl_0: str,
]).to(device=device)

cfg.attn_impl = attn_impl_0
attn0 = attention.ATTN_CLASS_REGISTRY[attn_type](**cfg).to(device)
attn0 = build_attention_layer(
name=attn_type,
attn_kwargs=om.to_container(cfg), # type: ignore
).to(device)
cfg.attn_impl = attn_impl_1
attn1 = attention.ATTN_CLASS_REGISTRY[attn_type](**cfg).to(device)
attn1 = build_attention_layer(
name=attn_type,
attn_kwargs=om.to_container(cfg), # type: ignore
).to(device)

attn1.load_state_dict(attn0.state_dict())

Expand Down
15 changes: 11 additions & 4 deletions tests/models/test_rope_dail_vs_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from omegaconf import OmegaConf as om

from llmfoundry.models.layers.attention import is_flash_v2_installed
from llmfoundry.models.layers.layer_builders import build_attention_layer
from llmfoundry.models.mpt.modeling_mpt import (gen_flash_attn_padding_info,
gen_rotary_embedding)

Expand All @@ -21,8 +22,6 @@ def test_rope_dail_vs_hf(attn_type: str, seq_len: int, device: str = 'cuda'):
if not is_flash_v2_installed():
pytest.skip('dail implementation of rope requires flash attention 2.')

from llmfoundry.models.layers import attention

cfg = om.create({
'attn_impl': 'flash',
'd_model': 128,
Expand All @@ -37,8 +36,16 @@ def test_rope_dail_vs_hf(attn_type: str, seq_len: int, device: str = 'cuda'):
if attn_type == 'grouped_query_attention':
cfg.kv_n_heads = 2

attn0 = attention.ATTN_CLASS_REGISTRY[attn_type](**cfg).to(device)
attn1 = attention.ATTN_CLASS_REGISTRY[attn_type](**cfg).to(device)
attn0 = build_attention_layer(
name=attn_type,
attn_kwargs=om.to_container(
cfg), # type: ignore (to_container return broad type)
).to(device)
attn1 = build_attention_layer(
name=attn_type,
attn_kwargs=om.to_container(
cfg), # type: ignore (to_container return broad type)
).to(device)

attn1.load_state_dict(attn0.state_dict())
x0 = torch.randn(batch_size, seq_len, cfg.d_model).to(device)
Expand Down
Loading

0 comments on commit 560012b

Please sign in to comment.