FFN layer registry #1095
Merged Apr 12, 2024 · 38 commits
Changes from 5 commits
4 changes: 1 addition & 3 deletions llmfoundry/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
MultiheadAttention, attn_bias_shape, build_alibi_bias, build_attn_bias,
flash_attn_fn, scaled_multihead_dot_product_attention)
from llmfoundry.models.layers.blocks import MPTBlock
from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY, MPTMLP, build_ffn
from llmfoundry.models.layers.ffn import MPTMLP
from llmfoundry.models.mpt import (ComposerMPTCausalLM, MPTConfig,
MPTForCausalLM, MPTModel, MPTPreTrainedModel)
from llmfoundry.tokenizers import TiktokenTokenizerWrapper
Expand All @@ -37,9 +37,7 @@
'build_finetuning_dataloader',
'Seq2SeqFinetuningCollator',
'MPTBlock',
'FFN_CLASS_REGISTRY',
'MPTMLP',
'build_ffn',
'MPTConfig',
'MPTPreTrainedModel',
'MPTModel',
Expand Down
22 changes: 21 additions & 1 deletion llmfoundry/layers_registry.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright 2024 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

from typing import Type
from typing import Callable, Type

import torch

Expand All @@ -15,6 +15,26 @@
entry_points=True,
description=_norm_description)

_ffns_description = (
'The ffns registry is used to register functions that build ffn layers.' +
'See ffn.py for examples.')
ffns = create_registry('llmfoundry',
'ffns',
generic_type=Callable,
entry_points=True,
description=_ffns_description)

_ffns_with_norm_description = (
'The ffns_with_norm registry is used to register functions that build ffn layers that apply a normalization layer.'
+ 'See ffn.py for examples.')
ffns_with_norm = create_registry('llmfoundry',
'ffns_with_norm',
generic_type=Callable,
entry_points=True,
description=_ffns_with_norm_description)

__all__ = [
'norms',
'ffns',
'ffns_with_norm',
]
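For context on what the new `ffns` registry enables, here is a rough sketch of how a downstream extension might register its own FFN builder. The builder name `my_gelu_mlp` and the function itself are hypothetical, the decorator usage assumes the catalogue-style `register` interface behind `create_registry`, and the expected keyword arguments are inferred from the `build_ffn` call shown later in blocks.py; none of this is part of the PR itself.

```python
from typing import Optional

import torch.nn as nn

from llmfoundry.layers_registry import ffns


# Hypothetical example builder -- not part of this PR.
@ffns.register('my_gelu_mlp')
def build_my_gelu_mlp(
    d_model: int,
    expansion_ratio: float,
    device: Optional[str] = None,
    bias: bool = True,
    **kwargs,
) -> nn.Module:
    """Build a plain GELU MLP; any extra ffn_config entries arrive via **kwargs."""
    hidden_size = int(d_model * expansion_ratio)
    return nn.Sequential(
        nn.Linear(d_model, hidden_size, bias=bias, device=device),
        nn.GELU(),
        nn.Linear(hidden_size, d_model, bias=bias, device=device),
    )
```

A model config could then select such a builder via `ffn_config.ffn_type`, since blocks.py pops `ffn_type` and forwards the remaining `ffn_config` entries as `ffn_kwargs`.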
4 changes: 1 addition & 3 deletions llmfoundry/models/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from llmfoundry.models.layers.blocks import MPTBlock
from llmfoundry.models.layers.custom_embedding import SharedEmbedding
from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY
from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY, MPTMLP, build_ffn
from llmfoundry.models.layers.ffn import MPTMLP
from llmfoundry.models.layers.norm import LPLayerNorm

__all__ = [
Expand All @@ -26,6 +26,4 @@
'LPLayerNorm',
'FC_CLASS_REGISTRY',
'SharedEmbedding',
'FFN_CLASS_REGISTRY',
'build_ffn',
]
22 changes: 16 additions & 6 deletions llmfoundry/models/layers/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@
import torch.nn as nn

from llmfoundry.models.layers.attention import ATTN_CLASS_REGISTRY
from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY, build_ffn
from llmfoundry.models.layers.layer_builders import build_norm
from llmfoundry.models.layers.layer_builders import build_ffn, build_norm

try:
from flash_attn.bert_padding import unpad_input, pad_input # type: ignore # yapf: disable # isort: skip
Expand Down Expand Up @@ -77,6 +76,7 @@ def __init__(
self.norm_attn_norm = FusedNormAttentionNorm(
d_model=d_model,
n_heads=n_heads,
expansion_ratio=expansion_ratio,
attn_config=attn_config,
ffn_config=ffn_config,
fc_type=fc_type,
Expand Down Expand Up @@ -115,8 +115,7 @@ def __init__(
bias=not no_bias,
)
self.norm_2 = None
if not getattr(FFN_CLASS_REGISTRY[ffn_config['ffn_type']],
'_has_norm', False):
if not getattr(self.ffn, '_has_norm', False):
self.norm_2 = build_norm(
name=norm_type.lower(),
normalized_shape=d_model,
Expand Down Expand Up @@ -196,6 +195,7 @@ def __init__(
self,
d_model: int,
n_heads: int,
expansion_ratio: float,
attn_config: Optional[Dict] = None,
ffn_config: Optional[Dict] = None,
fc_type: str = 'torch',
Expand Down Expand Up @@ -235,9 +235,19 @@ def __init__(
**attn_config_subset_for_attn_class,
bias=not no_bias,
)

ffn_type = ffn_config.pop('ffn_type')
self.ffn = build_ffn(
name=ffn_type,
d_model=d_model,
expansion_ratio=expansion_ratio,
device=device,
bias=not no_bias,
ffn_kwargs=ffn_config,
)

self.norm_2 = None
if not getattr(FFN_CLASS_REGISTRY[ffn_config['ffn_type']], '_has_norm',
False):
if not getattr(self.ffn, '_has_norm', False):
self.norm_2 = build_norm(
name=norm_type.lower(),
normalized_shape=d_model,
Expand Down
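The `layer_builders.build_ffn` helper referenced by the new import is not part of this diff, but the `getattr(self.ffn, '_has_norm', False)` checks above hint at its contract: it resolves the name against `ffns` or `ffns_with_norm`, builds the module, and modules coming from `ffns_with_norm` carry `_has_norm` so the block skips creating its own `norm_2`. A minimal sketch of that dispatch, under those assumptions (the function body, the `_has_norm` tagging, and the catalogue-style `__contains__`/`get` calls are a reconstruction, not the PR's code):

```python
from typing import Any, Dict, Optional

import torch.nn as nn

from llmfoundry.layers_registry import ffns, ffns_with_norm


def build_ffn_sketch(
    name: str,
    d_model: int,
    expansion_ratio: float,
    device: Optional[str],
    bias: bool,
    ffn_kwargs: Dict[str, Any],
) -> nn.Module:
    """Illustrative stand-in for layer_builders.build_ffn, not the actual implementation."""
    # Prefer builders that fuse their own normalization, if the name is registered there.
    registry_to_use = ffns_with_norm if name in ffns_with_norm else ffns

    ffn = registry_to_use.get(name)(
        d_model=d_model,
        expansion_ratio=expansion_ratio,
        device=device,
        bias=bias,
        **ffn_kwargs,
    )

    if registry_to_use is ffns_with_norm:
        # This is the attribute MPTBlock and FusedNormAttentionNorm check
        # before deciding whether to build their own norm_2.
        ffn._has_norm = True
    return ffn
```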