FFN layer registry #1095

Merged (38 commits) on Apr 12, 2024
4 changes: 1 addition & 3 deletions llmfoundry/__init__.py
@@ -28,7 +28,7 @@
     MultiheadAttention, attn_bias_shape, build_alibi_bias, build_attn_bias,
     flash_attn_fn, scaled_multihead_dot_product_attention)
 from llmfoundry.models.layers.blocks import MPTBlock
-from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY, MPTMLP, build_ffn
+from llmfoundry.models.layers.ffn import MPTMLP
 from llmfoundry.models.mpt import (ComposerMPTCausalLM, MPTConfig,
                                    MPTForCausalLM, MPTModel, MPTPreTrainedModel)
 from llmfoundry.tokenizers import TiktokenTokenizerWrapper
@@ -37,9 +37,7 @@
     'build_finetuning_dataloader',
     'Seq2SeqFinetuningCollator',
     'MPTBlock',
-    'FFN_CLASS_REGISTRY',
     'MPTMLP',
-    'build_ffn',
     'MPTConfig',
     'MPTPreTrainedModel',
     'MPTModel',
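The practical effect of this file for downstream imports: `FFN_CLASS_REGISTRY` and `build_ffn` are no longer re-exported from the top-level `llmfoundry` package. A rough before/after sketch, using the module paths that appear in the diffs below (the exact public re-export surface after this PR is an assumption):

```python
# Before this PR (old top-level re-exports):
# from llmfoundry import FFN_CLASS_REGISTRY, build_ffn

# After this PR, the FFN registries live in llmfoundry.layers_registry and the
# builder is reached through the layer_builders module (paths taken from the
# diffs in this PR; the precise import surface may differ).
from llmfoundry.layers_registry import ffns, ffns_with_norm, ffns_with_megablocks
from llmfoundry.models.layers.layer_builders import build_ffn
```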
31 changes: 31 additions & 0 deletions llmfoundry/layers_registry.py
@@ -26,6 +26,34 @@
                       entry_points=True,
                       description=_fc_description)
 
+_ffns_description = (
+    'The ffns registry is used to register functions that build ffn layers.' +
+    'See ffn.py for examples.')
+ffns = create_registry('llmfoundry',
+                       'ffns',
+                       generic_type=Callable,
+                       entry_points=True,
+                       description=_ffns_description)
+
+_ffns_with_norm_description = (
+    'The ffns_with_norm registry is used to register functions that build ffn layers that apply a normalization layer.'
+    + 'See ffn.py for examples.')
+ffns_with_norm = create_registry('llmfoundry',
+                                 'ffns_with_norm',
+                                 generic_type=Callable,
+                                 entry_points=True,
+                                 description=_ffns_with_norm_description)
+
+_ffns_with_megablocks_description = (
+    'The ffns_with_megablocks registry is used to register functions that build ffn layers using MegaBlocks.'
+    + 'See ffn.py for examples.')
+ffns_with_megablocks = create_registry(
+    'llmfoundry',
+    'ffns_with_megablocks',
+    generic_type=Callable,
+    entry_points=True,
+    description=_ffns_with_megablocks_description)
+
 _attention_classes_description = (
     'The attention_classes registry is used to register classes that implement attention layers. See '
     + 'attention.py for expected constructor signature.')
@@ -47,6 +75,9 @@
 
 __all__ = [
     'norms',
+    'ffns',
+    'ffns_with_norm',
+    'ffns_with_megablocks',
     'attention_classes',
     'attention_implementations',
     'fcs',
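For reviewers, here is a hypothetical sketch of registering a custom FFN builder against the new `ffns` registry. The decorator-style `register()` call and the builder signature (`d_model`, `expansion_ratio`, `device`, `bias`, plus extra ffn_config kwargs) are assumptions inferred from how `create_registry` and `build_ffn` are used elsewhere in this PR, not a verbatim copy of the llm-foundry API:

```python
from typing import Any, Optional

import torch.nn as nn

from llmfoundry.layers_registry import ffns


# Hypothetical registration of a plain GELU MLP under the name 'gelu_mlp'.
# Assumes the registry object supports decorator-style register(); the real
# registration API may differ.
@ffns.register('gelu_mlp')
def build_gelu_mlp(
    d_model: int,
    expansion_ratio: float,
    device: Optional[str] = None,
    bias: bool = True,
    **kwargs: Any,
) -> nn.Module:
    # Standard two-layer MLP: expand, apply GELU, project back down.
    hidden_size = int(d_model * expansion_ratio)
    return nn.Sequential(
        nn.Linear(d_model, hidden_size, bias=bias, device=device),
        nn.GELU(),
        nn.Linear(hidden_size, d_model, bias=bias, device=device),
    )
```

A builder whose module already applies its own normalization would go in `ffns_with_norm` instead, which lets `MPTBlock` skip building `norm_2` (see the `ffn_has_norm` check in blocks.py below); MegaBlocks-backed MoE builders go in `ffns_with_megablocks`.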
4 changes: 1 addition & 3 deletions llmfoundry/models/layers/__init__.py
@@ -8,7 +8,7 @@
 from llmfoundry.models.layers.blocks import MPTBlock
 from llmfoundry.models.layers.custom_embedding import SharedEmbedding
 from llmfoundry.models.layers.fc import *
-from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY, MPTMLP, build_ffn
+from llmfoundry.models.layers.ffn import MPTMLP
 from llmfoundry.models.layers.norm import LPLayerNorm
 
 __all__ = [
@@ -24,6 +24,4 @@
     'MPTBlock',
     'LPLayerNorm',
     'SharedEmbedding',
-    'FFN_CLASS_REGISTRY',
-    'build_ffn',
 ]
23 changes: 13 additions & 10 deletions llmfoundry/models/layers/blocks.py
@@ -8,9 +8,9 @@
 import torch
 import torch.nn as nn
 
-from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY, build_ffn
+from llmfoundry.layers_registry import ffns_with_norm
 from llmfoundry.models.layers.layer_builders import (build_attention_layer,
-                                                      build_norm)
+                                                      build_ffn, build_norm)
 
 try:
     from flash_attn.bert_padding import unpad_input, pad_input # type: ignore # yapf: disable # isort: skip
@@ -73,12 +73,15 @@ def __init__(
         del kwargs  # unused, just to capture any extra args from the config
         super().__init__()
 
+        ffn_type = ffn_config['ffn_type']
+        ffn_has_norm = ffn_type in ffns_with_norm
+
         if self.fuse_norm_attn_norm:
             self.norm_attn_norm = FusedNormAttentionNorm(
                 d_model=d_model,
                 n_heads=n_heads,
                 attn_config=attn_config,
-                ffn_config=ffn_config,
+                ffn_has_norm=ffn_has_norm,
                 fc_type=fc_type,
                 resid_pdrop=resid_pdrop,
                 norm_type=norm_type,
@@ -116,21 +119,22 @@ def __init__(
                 },
             )
             self.norm_2 = None
-            if not getattr(FFN_CLASS_REGISTRY[ffn_config['ffn_type']],
-                           '_has_norm', False):
+            if not ffn_has_norm:
                 self.norm_2 = build_norm(
                     name=norm_type.lower(),
                     normalized_shape=d_model,
                     device=device,
                 )
 
         self.ffn = build_ffn(
+            name=ffn_type,
             d_model=d_model,
             expansion_ratio=expansion_ratio,
             device=device,
             bias=not no_bias,
-            **ffn_config,
+            ffn_kwargs=ffn_config,
         )
+
         self.resid_attn_dropout = nn.Dropout(resid_pdrop)
         self.resid_ffn_dropout = nn.Dropout(resid_pdrop)
         self.use_pad_tok_in_ffn = use_pad_tok_in_ffn
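The new call site passes `name=ffn_type` and bundles the rest of the config into `ffn_kwargs`, so `build_ffn` in `layer_builders` presumably resolves the builder from the FFN registries rather than from `FFN_CLASS_REGISTRY`. Below is a minimal sketch of that dispatch pattern, not the PR's actual `layer_builders` implementation; `registry.get()` and the exact kwarg handling are assumptions, and only the membership check `name in registry` is confirmed by this diff:

```python
from typing import Any, Dict, Optional

import torch.nn as nn

from llmfoundry.layers_registry import ffns, ffns_with_megablocks, ffns_with_norm


def build_ffn_sketch(
    name: str,
    d_model: int,
    expansion_ratio: float,
    device: Optional[str],
    bias: bool,
    ffn_kwargs: Dict[str, Any],
) -> nn.Module:
    # Look the name up across the three FFN registries and call the registered
    # builder with the block-level args plus any extra ffn_config entries
    # (ffn_type itself is dropped before forwarding).
    for registry in (ffns, ffns_with_norm, ffns_with_megablocks):
        if name in registry:
            builder = registry.get(name)
            extra = {k: v for k, v in ffn_kwargs.items() if k != 'ffn_type'}
            return builder(
                d_model=d_model,
                expansion_ratio=expansion_ratio,
                device=device,
                bias=bias,
                **extra,
            )
    raise ValueError(f'Unknown ffn_type: {name}')
```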
@@ -198,7 +202,7 @@ def __init__(
         d_model: int,
         n_heads: int,
         attn_config: Optional[Dict] = None,
-        ffn_config: Optional[Dict] = None,
+        ffn_has_norm: bool = False,
         fc_type: str = 'torch',
         resid_pdrop: float = 0.0,
         norm_type: str = 'low_precision_layernorm',
@@ -208,7 +212,6 @@
     ):
         super().__init__()
         assert attn_config is not None
-        assert ffn_config is not None
         assert isinstance(attn_config['attn_type'], str)
 
         # necessary to avoid passing extraneous args into attn_class while allowing the use of **kwargs
@@ -238,9 +241,9 @@ def __init__(
                 **attn_config_subset_for_attn_class
             },
         )
+
         self.norm_2 = None
-        if not getattr(FFN_CLASS_REGISTRY[ffn_config['ffn_type']], '_has_norm',
-                       False):
+        if not ffn_has_norm:
             self.norm_2 = build_norm(
                 name=norm_type.lower(),
                 normalized_shape=d_model,