Attention layer registry #1094

Merged · 12 commits · Apr 12, 2024
28 changes: 25 additions & 3 deletions llmfoundry/layers_registry.py
@@ -1,20 +1,42 @@
# Copyright 2024 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

from typing import Type
from typing import Callable, Type

import torch

from llmfoundry.utils.registry_utils import create_registry

# Layers
_norm_description = """The norms registry is used to register classes that implement normalization layers."""
_norm_description = (
'The norms registry is used to register classes that implement normalization layers.'
)
norms = create_registry('llmfoundry',
'norms',
generic_type=Type[torch.nn.Module],
entry_points=True,
description=_norm_description)

_attention_classes_description = (
'The attention_classes registry is used to register classes that implement attention layers. See '
+ 'attention.py for expected constructor signature.')
attention_classes = create_registry('llmfoundry',
'attention_classes',
generic_type=Type[torch.nn.Module],
entry_points=True,
description=_attention_classes_description)

_attention_implementations_description = (
'The attention_implementations registry is used to register functions that implement the attention operation.'
+ ' See attention.py for expected function signature.')
attention_implementations = create_registry(
'llmfoundry',
'attention_implementations',
generic_type=Callable,
entry_points=True,
description=_attention_implementations_description)

__all__ = [
'norms',
'attention_classes',
'attention_implementations',
]
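
For illustration only (not part of this diff; MyAttention and my_attn_fn are hypothetical names), a downstream package could hook into the new registries roughly like this:

import torch

from llmfoundry.layers_registry import (attention_classes,
                                         attention_implementations)


@attention_classes.register_class('my_attention')
class MyAttention(torch.nn.Module):
    # Hypothetical attention layer; see attention.py for the expected
    # constructor signature.
    def __init__(self, d_model: int, n_heads: int, **kwargs: object):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads


def my_attn_fn(*args: object, **kwargs: object):
    # Hypothetical attention op; see attention.py for the expected
    # function signature.
    raise NotImplementedError


attention_implementations.register('my_attn_impl', func=my_attn_fn)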
7 changes: 3 additions & 4 deletions llmfoundry/models/layers/__init__.py
@@ -2,9 +2,9 @@
# SPDX-License-Identifier: Apache-2.0

from llmfoundry.models.layers.attention import (
ATTN_CLASS_REGISTRY, GroupedQueryAttention, MultiheadAttention,
MultiQueryAttention, attn_bias_shape, build_alibi_bias, build_attn_bias,
flash_attn_fn, scaled_multihead_dot_product_attention)
GroupedQueryAttention, MultiheadAttention, MultiQueryAttention,
attn_bias_shape, build_alibi_bias, build_attn_bias, flash_attn_fn,
scaled_multihead_dot_product_attention)
from llmfoundry.models.layers.blocks import MPTBlock
from llmfoundry.models.layers.custom_embedding import SharedEmbedding
from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY
@@ -20,7 +20,6 @@
'attn_bias_shape',
'build_attn_bias',
'build_alibi_bias',
'ATTN_CLASS_REGISTRY',
'MPTMLP',
'MPTBlock',
'LPLayerNorm',
20 changes: 9 additions & 11 deletions llmfoundry/models/layers/attention.py
@@ -14,6 +14,8 @@
from packaging import version
from torch import nn

from llmfoundry.layers_registry import (attention_classes,
attention_implementations)
from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY
from llmfoundry.models.layers.layer_builders import build_norm

@@ -341,6 +343,7 @@ def flash_attn_fn(
return output, None, past_key_value


@attention_classes.register_class('grouped_query_attention')
class GroupedQueryAttention(nn.Module):
"""Grouped Query Attention (GQA) is a generalization of Multi-head (MHA).

@@ -433,12 +436,7 @@ def __init__(
device=device,
)

if self.attn_impl == 'flash':
self.attn_fn = flash_attn_fn
elif self.attn_impl == 'torch':
self.attn_fn = scaled_multihead_dot_product_attention
else:
raise ValueError(f'{attn_impl=} is an invalid setting.')
self.attn_fn = attention_implementations.get(self.attn_impl)

self.out_proj = FC_CLASS_REGISTRY[fc_type](
self.d_model,
@@ -572,6 +570,7 @@ def forward(
return self.out_proj(context), attn_weights, past_key_value


@attention_classes.register_class('multihead_attention')
class MultiheadAttention(GroupedQueryAttention):
"""Multi-head self attention.

@@ -612,6 +611,7 @@ def __init__(
)


@attention_classes.register_class('multiquery_attention')
class MultiQueryAttention(GroupedQueryAttention):
"""Multi-Query self attention.

@@ -740,8 +740,6 @@ def build_alibi_bias(
return alibi_bias.to(dtype=dtype)


ATTN_CLASS_REGISTRY = {
'multihead_attention': MultiheadAttention,
'multiquery_attention': MultiQueryAttention,
'grouped_query_attention': GroupedQueryAttention
}
attention_implementations.register('flash', func=flash_attn_fn)
attention_implementations.register('torch',
func=scaled_multihead_dot_product_attention)
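
With those registrations in place, the old if/else over attn_impl in GroupedQueryAttention.__init__ collapses into a registry lookup. A minimal sketch of what the lookup returns (assuming get hands back the registered callable, as it does for norms elsewhere in this PR):

from llmfoundry.layers_registry import attention_implementations
from llmfoundry.models.layers.attention import flash_attn_fn

# 'flash' was registered with func=flash_attn_fn above, so the lookup
# should return that same function object.
attn_fn = attention_implementations.get('flash')
assert attn_fn is flash_attn_fn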
41 changes: 22 additions & 19 deletions llmfoundry/models/layers/blocks.py
@@ -8,9 +8,9 @@
import torch
import torch.nn as nn

from llmfoundry.models.layers.attention import ATTN_CLASS_REGISTRY
from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY, build_ffn
from llmfoundry.models.layers.layer_builders import build_norm
from llmfoundry.models.layers.layer_builders import (build_attention_layer,
build_norm)

try:
from flash_attn.bert_padding import unpad_input, pad_input # type: ignore # yapf: disable # isort: skip
@@ -87,8 +87,6 @@ def __init__(
)
else:
assert isinstance(attn_config['attn_type'], str)
attn_class = ATTN_CLASS_REGISTRY[attn_config['attn_type']]

# Necessary to avoid passing extraneous args into attn_class while allowing the use of **kwargs
args_to_exclude_in_attn_class = {
'attn_type', 'alibi', 'attn_uses_sequence_id', 'alibi_bias_max',
@@ -106,13 +104,16 @@
normalized_shape=d_model,
device=device,
)
self.attn = attn_class(
d_model=d_model,
n_heads=n_heads,
fc_type=fc_type,
device=device,
**attn_config_subset_for_attn_class,
bias=not no_bias,
self.attn = build_attention_layer(
name=attn_config['attn_type'],
attn_kwargs={
'd_model': d_model,
'n_heads': n_heads,
'fc_type': fc_type,
'device': device,
'bias': not no_bias,
**attn_config_subset_for_attn_class
},
)
self.norm_2 = None
if not getattr(FFN_CLASS_REGISTRY[ffn_config['ffn_type']],
@@ -209,7 +210,6 @@ def __init__(
assert attn_config is not None
assert ffn_config is not None
assert isinstance(attn_config['attn_type'], str)
attn_class = ATTN_CLASS_REGISTRY[attn_config['attn_type']]

# necessary to avoid passing extraneous args into attn_class while allowing the use of **kwargs
args_to_exclude_in_attn_class = {
@@ -227,13 +227,16 @@
normalized_shape=d_model,
device=device,
)
self.attn = attn_class(
d_model=d_model,
n_heads=n_heads,
fc_type=fc_type,
device=device,
**attn_config_subset_for_attn_class,
bias=not no_bias,
self.attn = build_attention_layer(
name=attn_config['attn_type'],
attn_kwargs={
'd_model': d_model,
'n_heads': n_heads,
'fc_type': fc_type,
'device': device,
'bias': not no_bias,
**attn_config_subset_for_attn_class
},
)
self.norm_2 = None
if not getattr(FFN_CLASS_REGISTRY[ffn_config['ffn_type']], '_has_norm',
14 changes: 12 additions & 2 deletions llmfoundry/models/layers/layer_builders.py
@@ -1,11 +1,11 @@
# Copyright 2024 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

from typing import List, Optional, Union
from typing import Any, Dict, List, Optional, Union

import torch

from llmfoundry.layers_registry import norms
from llmfoundry.layers_registry import attention_classes, norms
from llmfoundry.utils.registry_utils import construct_from_registry


@@ -23,3 +23,13 @@ def build_norm(
registry=norms,
pre_validation_function=torch.nn.Module,
kwargs=kwargs)


def build_attention_layer(
name: str,
attn_kwargs: Dict[str, Any],
):
return construct_from_registry(name=name,
registry=attention_classes,
pre_validation_function=torch.nn.Module,
kwargs=attn_kwargs)
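
A usage sketch of the new helper (the kwargs mirror what MPTBlock now passes in blocks.py; the specific values below are illustrative, not taken from this diff):

from llmfoundry.models.layers.layer_builders import build_attention_layer

attn = build_attention_layer(
    name='grouped_query_attention',
    attn_kwargs={
        'd_model': 128,
        'n_heads': 8,
        'kv_n_heads': 2,
        'attn_impl': 'torch',
    },
)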
12 changes: 6 additions & 6 deletions llmfoundry/models/utils/act_ckpt.py
@@ -5,8 +5,7 @@

import torch

from llmfoundry.layers_registry import norms
from llmfoundry.models.layers.attention import ATTN_CLASS_REGISTRY
from llmfoundry.layers_registry import attention_classes, norms
from llmfoundry.models.layers.blocks import FusedNormAttentionNorm, MPTBlock
from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY

@@ -25,18 +24,19 @@ def get_act_ckpt_module(mod_name: str) -> Any:
"""Get the module type from the module name."""
if mod_name.lower() == 'mptblock':
mod_type = MPTBlock
elif mod_name in attention_classes:
mod_type = attention_classes.get(mod_name)
elif mod_name.lower() == 'norm_attn_norm':
mod_type = FusedNormAttentionNorm
elif mod_name in ATTN_CLASS_REGISTRY:
mod_type = ATTN_CLASS_REGISTRY[mod_name]
elif mod_name in FFN_CLASS_REGISTRY:
mod_type = FFN_CLASS_REGISTRY[mod_name]
elif mod_name in norms:
mod_type = norms.get(mod_name)
else:
msg = ', '.join(
list(ATTN_CLASS_REGISTRY.keys()) + list(FFN_CLASS_REGISTRY.keys()) +
list(norms.get_all()) + ['MPTBlock'])
list(attention_classes.get_all()) +
list(FFN_CLASS_REGISTRY.keys()) + list(norms.get_all()) +
['MPTBlock'])
raise ValueError(
f'{mod_name} (specified in activation_checkpointing_target) is not a recognized option out of available options {msg}.'
)
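
In practice this means activation_checkpointing_target can now name any attention class registered in attention_classes. A small sketch of the resolution, assuming the registry lookups behave as shown in this diff:

from llmfoundry.models.layers.attention import GroupedQueryAttention
from llmfoundry.models.utils.act_ckpt import get_act_ckpt_module

# 'grouped_query_attention' resolves through the attention_classes
# registry instead of the removed ATTN_CLASS_REGISTRY dict.
mod_type = get_act_ckpt_module('grouped_query_attention')
assert mod_type is GroupedQueryAttention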
26 changes: 19 additions & 7 deletions llmfoundry/registry.py
@@ -12,7 +12,8 @@
from transformers import PreTrainedTokenizerBase

from llmfoundry.interfaces import CallbackWithConfig
from llmfoundry.layers_registry import norms
from llmfoundry.layers_registry import (attention_classes,
attention_implementations, norms)
from llmfoundry.utils.registry_utils import create_registry

_loggers_description = (
@@ -85,25 +86,34 @@
entry_points=True,
description=_schedulers_description)

_models_description = """The models registry is used to register classes that implement the ComposerModel interface. The model
constructor should accept two arguments: an omegaconf DictConfig named `om_model_config` and a PreTrainedTokenizerBase named `tokenizer`.
Note: This will soon be updated to take in named kwargs instead of a config directly."""
_models_description = (
'The models registry is used to register classes that implement the ComposerModel interface. '
+
'The model constructor should accept two arguments: an omegaconf DictConfig named `om_model_config` and a PreTrainedTokenizerBase named `tokenizer`. '
+
'Note: This will soon be updated to take in named kwargs instead of a config directly.'
)
models = create_registry('llmfoundry',
'models',
generic_type=Type[ComposerModel],
entry_points=True,
description=_models_description)

_dataloaders_description = """The dataloaders registry is used to register functions that create a DataSpec. The function should take
a DictConfig, a PreTrainedTokenizerBase, and an int as arguments, and return a DataSpec."""
_dataloaders_description = (
'The dataloaders registry is used to register functions that create a DataSpec. The function should take '
+
'a DictConfig, a PreTrainedTokenizerBase, and an int as arguments, and return a DataSpec.'
)
dataloaders = create_registry(
'llmfoundry',
'dataloaders',
generic_type=Callable[[DictConfig, PreTrainedTokenizerBase, int], DataSpec],
entry_points=True,
description=_dataloaders_description)

_metrics_description = """The metrics registry is used to register classes that implement the torchmetrics.Metric interface."""
_metrics_description = (
'The metrics registry is used to register classes that implement the torchmetrics.Metric interface.'
)
metrics = create_registry('llmfoundry',
'metrics',
generic_type=Type[Metric],
@@ -121,4 +131,6 @@
'metrics',
'dataloaders',
'norms',
'attention_classes',
'attention_implementations',
]
1 change: 1 addition & 0 deletions llmfoundry/utils/registry_utils.py
@@ -15,6 +15,7 @@

T = TypeVar('T')
TypeBoundT = TypeVar('TypeBoundT', bound=Type)
CallableBoundT = TypeVar('CallableBoundT', bound=Callable[..., Any])


class TypedRegistry(catalogue.Registry, Generic[T]):
11 changes: 9 additions & 2 deletions tests/models/layers/test_flash_torch.py
@@ -8,6 +8,7 @@
from llmfoundry.models.layers import attention
from llmfoundry.models.layers.attention import (check_alibi_support, gen_slopes,
is_flash_v2_installed)
from llmfoundry.models.layers.layer_builders import build_attention_layer
from llmfoundry.models.mpt.modeling_mpt import (apply_sequence_id,
gen_attention_mask_in_length,
gen_flash_attn_padding_info,
@@ -120,9 +121,15 @@ def test_attn_impl(attn_impl_0: str,
]).to(device=device)

cfg.attn_impl = attn_impl_0
attn0 = attention.ATTN_CLASS_REGISTRY[attn_type](**cfg).to(device)
attn0 = build_attention_layer(
name=attn_type,
attn_kwargs=om.to_container(cfg), # type: ignore
).to(device)
cfg.attn_impl = attn_impl_1
attn1 = attention.ATTN_CLASS_REGISTRY[attn_type](**cfg).to(device)
attn1 = build_attention_layer(
name=attn_type,
attn_kwargs=om.to_container(cfg), # type: ignore
).to(device)

attn1.load_state_dict(attn0.state_dict())

13 changes: 9 additions & 4 deletions tests/models/test_rope_dail_vs_hf.py
@@ -7,6 +7,7 @@
from omegaconf import OmegaConf as om

from llmfoundry.models.layers.attention import is_flash_v2_installed
from llmfoundry.models.layers.layer_builders import build_attention_layer
from llmfoundry.models.mpt.modeling_mpt import (gen_flash_attn_padding_info,
gen_rotary_embedding)

@@ -21,8 +22,6 @@ def test_rope_dail_vs_hf(attn_type: str, seq_len: int, device: str = 'cuda'):
if not is_flash_v2_installed():
pytest.skip('dail implementation of rope requires flash attention 2.')

from llmfoundry.models.layers import attention

cfg = om.create({
'attn_impl': 'flash',
'd_model': 128,
@@ -37,8 +36,14 @@
if attn_type == 'grouped_query_attention':
cfg.kv_n_heads = 2

attn0 = attention.ATTN_CLASS_REGISTRY[attn_type](**cfg).to(device)
attn1 = attention.ATTN_CLASS_REGISTRY[attn_type](**cfg).to(device)
attn0 = build_attention_layer(
name=attn_type,
attn_kwargs=om.to_container(cfg), # type: ignore
).to(device)
attn1 = build_attention_layer(
name=attn_type,
attn_kwargs=om.to_container(cfg), # type: ignore
).to(device)

attn1.load_state_dict(attn0.state_dict())
x0 = torch.randn(batch_size, seq_len, cfg.d_model).to(device)
2 changes: 2 additions & 0 deletions tests/test_registry.py
@@ -31,6 +31,8 @@ def test_expected_registries_exist():
'metrics',
'models',
'norms',
'attention_classes',
'attention_implementations',
}

assert existing_registries == expected_registry_names