Commit

merge

dakinggg committed Apr 12, 2024
2 parents f7e4fec + ed3daef commit 73db66c
Showing 10 changed files with 69 additions and 37 deletions.
11 changes: 11 additions & 0 deletions llmfoundry/layers_registry.py
@@ -14,6 +14,16 @@
generic_type=Type[torch.nn.Module],
entry_points=True,
description=_norm_description)
_fc_description = (
'The fully connected layers registry is used to register classes that implement fully connected layers (i.e. torch.nn.Linear).'
+
'These classes should take in_features and out_features in as args, at a minimum.'
)
fcs = create_registry('llmfoundry',
'fcs',
generic_type=Type[torch.nn.Module],
entry_points=True,
description=_fc_description)

_ffns_description = (
'The ffns registry is used to register functions that build ffn layers.' +
@@ -48,4 +58,5 @@
'ffns',
'ffns_with_norm',
'ffns_with_megablocks',
'fcs',
]
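For context, a minimal sketch (not part of this commit) of how a custom fully connected layer might be registered against the new `fcs` registry. The `LowRankLinear` class and the `'low_rank'` name are hypothetical; the description above only requires that registered classes accept `in_features` and `out_features`.

```python
# Hypothetical example (not part of this commit): registering a custom fully
# connected layer with the new `fcs` registry.
import torch
from torch import nn

from llmfoundry.layers_registry import fcs


class LowRankLinear(nn.Module):
    """Toy rank-factorized linear layer with an nn.Linear-like signature."""

    def __init__(self, in_features: int, out_features: int, rank: int = 16,
                 bias: bool = True, device=None):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.down = nn.Linear(in_features, rank, bias=False, device=device)
        self.up = nn.Linear(rank, out_features, bias=bias, device=device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.up(self.down(x))


# Same registration pattern the commit uses for nn.Linear ('torch') in fc.py.
fcs.register('low_rank', func=LowRankLinear)
```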
3 changes: 1 addition & 2 deletions llmfoundry/models/layers/__init__.py
@@ -7,7 +7,7 @@
flash_attn_fn, scaled_multihead_dot_product_attention)
from llmfoundry.models.layers.blocks import MPTBlock
from llmfoundry.models.layers.custom_embedding import SharedEmbedding
from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY
from llmfoundry.models.layers.fc import *
from llmfoundry.models.layers.ffn import MPTMLP
from llmfoundry.models.layers.norm import LPLayerNorm

@@ -24,6 +24,5 @@
'MPTMLP',
'MPTBlock',
'LPLayerNorm',
'FC_CLASS_REGISTRY',
'SharedEmbedding',
]
21 changes: 11 additions & 10 deletions llmfoundry/models/layers/attention.py
@@ -14,8 +14,7 @@
from packaging import version
from torch import nn

from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY
from llmfoundry.models.layers.layer_builders import build_norm
from llmfoundry.models.layers.layer_builders import build_fc, build_norm


def is_flash_v2_installed(v2_version: str = '2.0.0'):
@@ -406,10 +405,11 @@ def __init__(
'bias': bias,
}
fc_kwargs['device'] = device
self.Wqkv = FC_CLASS_REGISTRY[fc_type](
self.d_model,
self.d_model + 2 * self.kv_n_heads * self.head_dim,
**fc_kwargs,
self.Wqkv = build_fc(
name=fc_type,
in_features=self.d_model,
out_features=self.d_model + 2 * self.kv_n_heads * self.head_dim,
fc_kwargs=fc_kwargs,
)
# for param init fn; enables shape based init of fused layers
fuse_splits = [
@@ -440,10 +440,11 @@ def __init__(
else:
raise ValueError(f'{attn_impl=} is an invalid setting.')

self.out_proj = FC_CLASS_REGISTRY[fc_type](
self.d_model,
self.d_model,
**fc_kwargs,
self.out_proj = build_fc(
name=fc_type,
in_features=self.d_model,
out_features=self.d_model,
fc_kwargs=fc_kwargs,
)
self.out_proj._is_residual = True

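With `fc_type='torch'`, `build_fc` is assumed to resolve to `nn.Linear` through the `fcs` registry, so the new `Wqkv`/`out_proj` construction should match the previous direct `FC_CLASS_REGISTRY[fc_type](...)` call. A hedged equivalence check; all sizes and the `fc_kwargs` values are placeholders.

```python
# Hedged equivalence check between the old and new construction paths.
from torch import nn

from llmfoundry.models.layers.layer_builders import build_fc

d_model, kv_n_heads, head_dim = 256, 4, 64
fc_kwargs = {'bias': True, 'device': 'cpu'}

wqkv_new = build_fc(
    name='torch',
    in_features=d_model,
    out_features=d_model + 2 * kv_n_heads * head_dim,
    fc_kwargs=fc_kwargs,
)
wqkv_old = nn.Linear(d_model, d_model + 2 * kv_n_heads * head_dim, **fc_kwargs)

assert isinstance(wqkv_new, nn.Linear)
assert wqkv_new.weight.shape == wqkv_old.weight.shape
```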
8 changes: 4 additions & 4 deletions llmfoundry/models/layers/fc.py
@@ -3,12 +3,12 @@

from torch import nn

FC_CLASS_REGISTRY = {
'torch': nn.Linear,
}
from llmfoundry.layers_registry import fcs

fcs.register('torch', func=nn.Linear)

try:
import transformer_engine.pytorch as te
FC_CLASS_REGISTRY['te'] = te.Linear
fcs.register('te', func=te.Linear)
except:
pass
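With the module-level dict gone, lookups that previously read `FC_CLASS_REGISTRY['torch']` would instead go through the `fcs` registry. A hedged sketch of equivalent lookups, assuming the `get`/`get_all` accessors this commit uses elsewhere (see param_init_fns.py below).

```python
# Hedged sketch of registry lookups after this change.
from torch import nn

import llmfoundry.models.layers.fc  # noqa: F401  (importing runs the registrations above)
from llmfoundry.layers_registry import fcs

assert fcs.get('torch') is nn.Linear   # replaces FC_CLASS_REGISTRY['torch']
print(list(fcs.get_all()))             # registered names, e.g. ['torch'] (+ 'te' if installed)

if 'te' in fcs.get_all():              # transformer_engine backend is optional
    te_linear_cls = fcs.get('te')
```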
31 changes: 17 additions & 14 deletions llmfoundry/models/layers/ffn.py
@@ -16,7 +16,7 @@
from llmfoundry.layers_registry import (ffns, ffns_with_megablocks,
ffns_with_norm)
from llmfoundry.models.layers.dmoe import dMoE
from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY
from llmfoundry.models.layers.layer_builders import build_fc

try:
import transformer_engine.pytorch as te
@@ -55,7 +55,7 @@ def resolve_ffn_act_fn(
config = deepcopy(config)
name = config.pop('name')
if not hasattr(torch.nn.functional, name):
raise ValueError(f'Unrecognised activation function name ({name}).')
raise ValueError(f'Unrecognized activation function name ({name}).')
act = getattr(torch.nn.functional, name)
return partial(act, **config)

@@ -124,16 +124,18 @@ def __init__(

self.fc_kwargs['device'] = device

self.up_proj = FC_CLASS_REGISTRY[fc_type](
d_model,
ffn_hidden_size,
**self.fc_kwargs,
self.up_proj = build_fc(
name=fc_type,
in_features=d_model,
out_features=ffn_hidden_size,
fc_kwargs=self.fc_kwargs,
)
self.act = act_fn
self.down_proj = FC_CLASS_REGISTRY[fc_type](
ffn_hidden_size,
d_model,
**self.fc_kwargs,
self.down_proj = build_fc(
name=fc_type,
in_features=ffn_hidden_size,
out_features=d_model,
fc_kwargs=self.fc_kwargs,
)
self.down_proj._is_residual = True

@@ -162,10 +164,11 @@ def __init__(
device=device,
bias=bias,
)
self.gate_proj = FC_CLASS_REGISTRY[fc_type](
d_model,
self.up_proj.out_features,
**self.fc_kwargs,
self.gate_proj = build_fc(
name=fc_type,
in_features=d_model,
out_features=self.up_proj.out_features,
fc_kwargs=self.fc_kwargs,
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
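The GLU variant adds a gate projection built through the same `build_fc` path as `up_proj`/`down_proj`. A rough sketch of the GLU-style combination such a trio of projections typically implements; the `silu` activation and all sizes are placeholders, not taken from this commit.

```python
# Rough GLU-style sketch (not the commit's exact forward).
import torch
import torch.nn.functional as F

from llmfoundry.models.layers.layer_builders import build_fc

d_model, ffn_hidden_size = 256, 1024
fc_kwargs = {'bias': True, 'device': 'cpu'}

gate_proj = build_fc(name='torch', in_features=d_model,
                     out_features=ffn_hidden_size, fc_kwargs=fc_kwargs)
up_proj = build_fc(name='torch', in_features=d_model,
                   out_features=ffn_hidden_size, fc_kwargs=fc_kwargs)
down_proj = build_fc(name='torch', in_features=ffn_hidden_size,
                     out_features=d_model, fc_kwargs=fc_kwargs)

x = torch.randn(2, 8, d_model)
y = down_proj(F.silu(gate_proj(x)) * up_proj(x))   # shape: (2, 8, d_model)
```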
20 changes: 19 additions & 1 deletion llmfoundry/models/layers/layer_builders.py
@@ -5,7 +5,7 @@

import torch

from llmfoundry.layers_registry import (ffns, ffns_with_megablocks,
from llmfoundry.layers_registry import (fcs, ffns, ffns_with_megablocks,
ffns_with_norm, norms)
from llmfoundry.utils.registry_utils import construct_from_registry

@@ -67,3 +67,21 @@ def _validation_function(maybe_module: Any):
result._uses_megablocks = True

return result


def build_fc(
name: str,
in_features: int,
out_features: int,
fc_kwargs: Dict[str, Any],
):
kwargs = {
'in_features': in_features,
'out_features': out_features,
**fc_kwargs,
}

return construct_from_registry(name=name,
registry=fcs,
pre_validation_function=torch.nn.Module,
kwargs=kwargs)
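A hedged usage sketch of the new builder: `in_features`/`out_features` are merged with the caller's `fc_kwargs` and forwarded to the registered class, so backend-specific options ride along. `'torch'` is assumed to resolve to `nn.Linear` as registered in fc.py.

```python
# Hedged usage sketch of build_fc; all values are placeholders.
from llmfoundry.models.layers.layer_builders import build_fc

proj = build_fc(
    name='torch',
    in_features=512,
    out_features=2048,
    fc_kwargs={'bias': False, 'device': 'cpu'},
)
print(type(proj).__name__, tuple(proj.weight.shape))   # Linear (2048, 512)
```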
3 changes: 1 addition & 2 deletions llmfoundry/models/mpt/configuration_mpt.py
@@ -17,9 +17,8 @@
# HuggingFace can detect all the needed files to copy into its modules folder.
# Otherwise, certain modules are missing.
# isort: off
from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY # type: ignore (see note)
from llmfoundry.models.layers.norm import LPLayerNorm # type: ignore (see note)
from llmfoundry.models.layers.layer_builders import build_norm, build_ffn # type: ignore (see note)
from llmfoundry.models.layers.layer_builders import build_norm, build_fc, build_ffn # type: ignore (see note)
from llmfoundry.models.layers.dmoe import dMoE # type: ignore (see note)
from llmfoundry.layers_registry import norms # type: ignore (see note)
from llmfoundry.utils.registry_utils import construct_from_registry # type: ignore (see note)
5 changes: 2 additions & 3 deletions llmfoundry/models/utils/param_init_fns.py
@@ -12,9 +12,8 @@
from torch import nn
from torch.distributed._tensor import DTensor

from llmfoundry.layers_registry import norms
from llmfoundry.layers_registry import fcs, norms
from llmfoundry.models.layers.dmoe import GLU, MLP
from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY

try:
import transformer_engine.pytorch as te
@@ -182,7 +181,7 @@ def generic_param_init_fn_(
f'Expected init_div_is_residual to be boolean or numeric, got {init_div_is_residual}'
)

if isinstance(module, tuple(set(FC_CLASS_REGISTRY.values()))):
if isinstance(module, tuple(set([fcs.get(n) for n in fcs.get_all()]))):
# Linear
if hasattr(module, '_fused'):
fused_init_helper_(module, init_fn_)
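The module check in `generic_param_init_fn_` now derives the set of fully connected classes from the registry instead of `FC_CLASS_REGISTRY.values()`. A standalone sketch of that membership test; the `nn.Linear` instance is just an example input.

```python
# Sketch of the new membership test against all registered fc classes.
import llmfoundry.models.layers.fc  # noqa: F401  (importing runs the fcs registrations)
from torch import nn

from llmfoundry.layers_registry import fcs

module = nn.Linear(16, 16)
fc_classes = tuple(set(fcs.get(n) for n in fcs.get_all()))
print(isinstance(module, fc_classes))   # True: nn.Linear is registered as 'torch'
```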
3 changes: 2 additions & 1 deletion llmfoundry/registry.py
@@ -12,7 +12,7 @@
from transformers import PreTrainedTokenizerBase

from llmfoundry.interfaces import CallbackWithConfig
from llmfoundry.layers_registry import (ffns, ffns_with_megablocks,
from llmfoundry.layers_registry import (fcs, ffns, ffns_with_megablocks,
ffns_with_norm, norms)
from llmfoundry.utils.registry_utils import create_registry

@@ -125,4 +125,5 @@
'ffns',
'ffns_with_norm',
'ffns_with_megablocks',
'fcs',
]
1 change: 1 addition & 0 deletions tests/test_registry.py
@@ -33,6 +33,7 @@ def test_expected_registries_exist():
'norms',
'ffns',
'ffns_with_norm',
'fcs',
}

assert existing_registries == expected_registry_names
