Skip to content

Commit

Permalink
FC layer registry (#1093)
Browse files Browse the repository at this point in the history
  • Loading branch information
dakinggg authored Apr 12, 2024
1 parent 4cd2324 commit ed3daef
Show file tree
Hide file tree
Showing 10 changed files with 70 additions and 38 deletions.
11 changes: 11 additions & 0 deletions llmfoundry/layers_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,18 @@
generic_type=Type[torch.nn.Module],
entry_points=True,
description=_norm_description)
_fc_description = (
'The fully connected layers registry is used to register classes that implement fully connected layers (i.e. torch.nn.Linear).'
+
'These classes should take in_features and out_features in as args, at a minimum.'
)
fcs = create_registry('llmfoundry',
'fcs',
generic_type=Type[torch.nn.Module],
entry_points=True,
description=_fc_description)

__all__ = [
'norms',
'fcs',
]
3 changes: 1 addition & 2 deletions llmfoundry/models/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
flash_attn_fn, scaled_multihead_dot_product_attention)
from llmfoundry.models.layers.blocks import MPTBlock
from llmfoundry.models.layers.custom_embedding import SharedEmbedding
from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY
from llmfoundry.models.layers.fc import *
from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY, MPTMLP, build_ffn
from llmfoundry.models.layers.norm import LPLayerNorm

Expand All @@ -24,7 +24,6 @@
'MPTMLP',
'MPTBlock',
'LPLayerNorm',
'FC_CLASS_REGISTRY',
'SharedEmbedding',
'FFN_CLASS_REGISTRY',
'build_ffn',
Expand Down
21 changes: 11 additions & 10 deletions llmfoundry/models/layers/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@
from packaging import version
from torch import nn

from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY
from llmfoundry.models.layers.layer_builders import build_norm
from llmfoundry.models.layers.layer_builders import build_fc, build_norm


def is_flash_v2_installed(v2_version: str = '2.0.0'):
Expand Down Expand Up @@ -406,10 +405,11 @@ def __init__(
'bias': bias,
}
fc_kwargs['device'] = device
self.Wqkv = FC_CLASS_REGISTRY[fc_type](
self.d_model,
self.d_model + 2 * self.kv_n_heads * self.head_dim,
**fc_kwargs,
self.Wqkv = build_fc(
name=fc_type,
in_features=self.d_model,
out_features=self.d_model + 2 * self.kv_n_heads * self.head_dim,
fc_kwargs=fc_kwargs,
)
# for param init fn; enables shape based init of fused layers
fuse_splits = [
Expand Down Expand Up @@ -440,10 +440,11 @@ def __init__(
else:
raise ValueError(f'{attn_impl=} is an invalid setting.')

self.out_proj = FC_CLASS_REGISTRY[fc_type](
self.d_model,
self.d_model,
**fc_kwargs,
self.out_proj = build_fc(
name=fc_type,
in_features=self.d_model,
out_features=self.d_model,
fc_kwargs=fc_kwargs,
)
self.out_proj._is_residual = True

Expand Down
8 changes: 4 additions & 4 deletions llmfoundry/models/layers/fc.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@

from torch import nn

FC_CLASS_REGISTRY = {
'torch': nn.Linear,
}
from llmfoundry.layers_registry import fcs

fcs.register('torch', func=nn.Linear)

try:
import transformer_engine.pytorch as te
FC_CLASS_REGISTRY['te'] = te.Linear
fcs.register('te', func=te.Linear)
except:
pass
31 changes: 17 additions & 14 deletions llmfoundry/models/layers/ffn.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from torch.distributed._tensor import DeviceMesh, DTensor, Placement, Shard

from llmfoundry.models.layers.dmoe import dMoE
from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY
from llmfoundry.models.layers.layer_builders import build_fc

try:
import transformer_engine.pytorch as te
Expand Down Expand Up @@ -52,7 +52,7 @@ def resolve_ffn_act_fn(
config = deepcopy(config)
name = config.pop('name')
if not hasattr(torch.nn.functional, name):
raise ValueError(f'Unrecognised activation function name ({name}).')
raise ValueError(f'Unrecognized activation function name ({name}).')
act = getattr(torch.nn.functional, name)
return partial(act, **config)

Expand Down Expand Up @@ -121,16 +121,18 @@ def __init__(

self.fc_kwargs['device'] = device

self.up_proj = FC_CLASS_REGISTRY[fc_type](
d_model,
ffn_hidden_size,
**self.fc_kwargs,
self.up_proj = build_fc(
name=fc_type,
in_features=d_model,
out_features=ffn_hidden_size,
fc_kwargs=self.fc_kwargs,
)
self.act = act_fn
self.down_proj = FC_CLASS_REGISTRY[fc_type](
ffn_hidden_size,
d_model,
**self.fc_kwargs,
self.down_proj = build_fc(
name=fc_type,
in_features=ffn_hidden_size,
out_features=d_model,
fc_kwargs=self.fc_kwargs,
)
self.down_proj._is_residual = True

Expand Down Expand Up @@ -159,10 +161,11 @@ def __init__(
device=device,
bias=bias,
)
self.gate_proj = FC_CLASS_REGISTRY[fc_type](
d_model,
self.up_proj.out_features,
**self.fc_kwargs,
self.gate_proj = build_fc(
name=fc_type,
in_features=d_model,
out_features=self.up_proj.out_features,
fc_kwargs=self.fc_kwargs,
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
Expand Down
22 changes: 20 additions & 2 deletions llmfoundry/models/layers/layer_builders.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
# Copyright 2024 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

from typing import List, Optional, Union
from typing import Any, Dict, List, Optional, Union

import torch

from llmfoundry.layers_registry import norms
from llmfoundry.layers_registry import fcs, norms
from llmfoundry.utils.registry_utils import construct_from_registry


Expand All @@ -23,3 +23,21 @@ def build_norm(
registry=norms,
pre_validation_function=torch.nn.Module,
kwargs=kwargs)


def build_fc(
name: str,
in_features: int,
out_features: int,
fc_kwargs: Dict[str, Any],
):
kwargs = {
'in_features': in_features,
'out_features': out_features,
**fc_kwargs,
}

return construct_from_registry(name=name,
registry=fcs,
pre_validation_function=torch.nn.Module,
kwargs=kwargs)
3 changes: 1 addition & 2 deletions llmfoundry/models/mpt/configuration_mpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,10 @@
# HuggingFace can detect all the needed files to copy into its modules folder.
# Otherwise, certain modules are missing.
# isort: off
from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY # type: ignore (see note)
from llmfoundry.models.layers.norm import LPLayerNorm # type: ignore (see note)
from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY # type: ignore (see note)
from llmfoundry.models.layers.layer_builders import build_norm, build_fc # type: ignore (see note)
from llmfoundry.models.layers.dmoe import dMoE # type: ignore (see note)
from llmfoundry.models.layers.layer_builders import build_norm # type: ignore (see note)
from llmfoundry.layers_registry import norms # type: ignore (see note)
from llmfoundry.utils.registry_utils import construct_from_registry # type: ignore (see note)

Expand Down
5 changes: 2 additions & 3 deletions llmfoundry/models/utils/param_init_fns.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,8 @@
from torch import nn
from torch.distributed._tensor import DTensor

from llmfoundry.layers_registry import norms
from llmfoundry.layers_registry import fcs, norms
from llmfoundry.models.layers.dmoe import GLU, MLP
from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY

try:
import transformer_engine.pytorch as te
Expand Down Expand Up @@ -182,7 +181,7 @@ def generic_param_init_fn_(
f'Expected init_div_is_residual to be boolean or numeric, got {init_div_is_residual}'
)

if isinstance(module, tuple(set(FC_CLASS_REGISTRY.values()))):
if isinstance(module, tuple(set([fcs.get(n) for n in fcs.get_all()]))):
# Linear
if hasattr(module, '_fused'):
fused_init_helper_(module, init_fn_)
Expand Down
3 changes: 2 additions & 1 deletion llmfoundry/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from transformers import PreTrainedTokenizerBase

from llmfoundry.interfaces import CallbackWithConfig
from llmfoundry.layers_registry import norms
from llmfoundry.layers_registry import fcs, norms
from llmfoundry.utils.registry_utils import create_registry

_loggers_description = (
Expand Down Expand Up @@ -121,4 +121,5 @@
'metrics',
'dataloaders',
'norms',
'fcs',
]
1 change: 1 addition & 0 deletions tests/test_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def test_expected_registries_exist():
'metrics',
'models',
'norms',
'fcs',
}

assert existing_registries == expected_registry_names
Expand Down

0 comments on commit ed3daef

Please sign in to comment.