FC layer registry #1093

Merged · 6 commits · Apr 12, 2024
11 changes: 11 additions & 0 deletions llmfoundry/layers_registry.py
@@ -14,7 +14,18 @@
generic_type=Type[torch.nn.Module],
entry_points=True,
description=_norm_description)
_fc_description = (
'The fully connected layers registry is used to register classes that implement fully connected layers (i.e. torch.nn.Linear).'
+
'These classes should take in_features and out_features in as args, at a minimum.'
)
fcs = create_registry('llmfoundry',
'fcs',
generic_type=Type[torch.nn.Module],
entry_points=True,
description=_fc_description)

__all__ = [
'norms',
'fcs',
]
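
For illustration, a minimal sketch of how a downstream project could register its own fully connected layer against the new `fcs` registry, mirroring the built-in registration in `fc.py` below. `ScaledLinear` and the `'scaled'` key are hypothetical and not part of this PR.

```python
import torch
from torch import nn

from llmfoundry.layers_registry import fcs


class ScaledLinear(nn.Linear):
    """Hypothetical FC layer; takes in_features and out_features like nn.Linear."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return 0.5 * super().forward(x)


# Register under a custom name; the entry can then be selected via fc_type.
fcs.register('scaled', func=ScaledLinear)
```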
3 changes: 1 addition & 2 deletions llmfoundry/models/layers/__init__.py
@@ -7,7 +7,7 @@
flash_attn_fn, scaled_multihead_dot_product_attention)
from llmfoundry.models.layers.blocks import MPTBlock
from llmfoundry.models.layers.custom_embedding import SharedEmbedding
from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY
from llmfoundry.models.layers.fc import *
from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY, MPTMLP, build_ffn
from llmfoundry.models.layers.norm import LPLayerNorm

@@ -24,7 +24,6 @@
'MPTMLP',
'MPTBlock',
'LPLayerNorm',
'FC_CLASS_REGISTRY',
'SharedEmbedding',
'FFN_CLASS_REGISTRY',
'build_ffn',
21 changes: 11 additions & 10 deletions llmfoundry/models/layers/attention.py
@@ -14,8 +14,7 @@
from packaging import version
from torch import nn

from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY
from llmfoundry.models.layers.layer_builders import build_norm
from llmfoundry.models.layers.layer_builders import build_fc, build_norm


def is_flash_v2_installed(v2_version: str = '2.0.0'):
@@ -406,10 +405,11 @@ def __init__(
'bias': bias,
}
fc_kwargs['device'] = device
self.Wqkv = FC_CLASS_REGISTRY[fc_type](
self.d_model,
self.d_model + 2 * self.kv_n_heads * self.head_dim,
**fc_kwargs,
self.Wqkv = build_fc(
name=fc_type,
in_features=self.d_model,
out_features=self.d_model + 2 * self.kv_n_heads * self.head_dim,
fc_kwargs=fc_kwargs,
)
# for param init fn; enables shape based init of fused layers
fuse_splits = [
@@ -440,10 +440,11 @@ def __init__(
else:
raise ValueError(f'{attn_impl=} is an invalid setting.')

self.out_proj = FC_CLASS_REGISTRY[fc_type](
self.d_model,
self.d_model,
**fc_kwargs,
self.out_proj = build_fc(
name=fc_type,
in_features=self.d_model,
out_features=self.d_model,
fc_kwargs=fc_kwargs,
)
self.out_proj._is_residual = True

8 changes: 4 additions & 4 deletions llmfoundry/models/layers/fc.py
@@ -3,12 +3,12 @@

from torch import nn

FC_CLASS_REGISTRY = {
'torch': nn.Linear,
}
from llmfoundry.layers_registry import fcs

fcs.register('torch', func=nn.Linear)

try:
import transformer_engine.pytorch as te
FC_CLASS_REGISTRY['te'] = te.Linear
fcs.register('te', func=te.Linear)
except:
pass
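
With the module-level dict gone, available implementations can be discovered through the registry itself. A small sketch, assuming `get_all()` returns the registered names, as the updated `param_init_fns.py` below relies on:

```python
from llmfoundry.layers_registry import fcs

# 'torch' is always registered; 'te' only appears if transformer_engine
# imports cleanly (see the try/except above).
available_names = fcs.get_all()
linear_cls = fcs.get('torch')
print(available_names, linear_cls)
```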
31 changes: 17 additions & 14 deletions llmfoundry/models/layers/ffn.py
@@ -13,7 +13,7 @@
from torch.distributed._tensor import DeviceMesh, DTensor, Placement, Shard

from llmfoundry.models.layers.dmoe import dMoE
from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY
from llmfoundry.models.layers.layer_builders import build_fc

try:
import transformer_engine.pytorch as te
@@ -52,7 +52,7 @@ def resolve_ffn_act_fn(
config = deepcopy(config)
name = config.pop('name')
if not hasattr(torch.nn.functional, name):
raise ValueError(f'Unrecognised activation function name ({name}).')
raise ValueError(f'Unrecognized activation function name ({name}).')
act = getattr(torch.nn.functional, name)
return partial(act, **config)

@@ -121,16 +121,18 @@ def __init__(

self.fc_kwargs['device'] = device

self.up_proj = FC_CLASS_REGISTRY[fc_type](
d_model,
ffn_hidden_size,
**self.fc_kwargs,
self.up_proj = build_fc(
name=fc_type,
in_features=d_model,
out_features=ffn_hidden_size,
fc_kwargs=self.fc_kwargs,
)
self.act = act_fn
self.down_proj = FC_CLASS_REGISTRY[fc_type](
ffn_hidden_size,
d_model,
**self.fc_kwargs,
self.down_proj = build_fc(
name=fc_type,
in_features=ffn_hidden_size,
out_features=d_model,
fc_kwargs=self.fc_kwargs,
)
self.down_proj._is_residual = True

@@ -159,10 +161,11 @@ def __init__(
device=device,
bias=bias,
)
self.gate_proj = FC_CLASS_REGISTRY[fc_type](
d_model,
self.up_proj.out_features,
**self.fc_kwargs,
self.gate_proj = build_fc(
name=fc_type,
in_features=d_model,
out_features=self.up_proj.out_features,
fc_kwargs=self.fc_kwargs,
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
22 changes: 20 additions & 2 deletions llmfoundry/models/layers/layer_builders.py
@@ -1,11 +1,11 @@
# Copyright 2024 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

from typing import List, Optional, Union
from typing import Any, Dict, List, Optional, Union

import torch

from llmfoundry.layers_registry import norms
from llmfoundry.layers_registry import fcs, norms
from llmfoundry.utils.registry_utils import construct_from_registry


@@ -23,3 +23,21 @@ def build_norm(
registry=norms,
pre_validation_function=torch.nn.Module,
kwargs=kwargs)


def build_fc(
name: str,
in_features: int,
out_features: int,
fc_kwargs: Dict[str, Any],
):
kwargs = {
'in_features': in_features,
'out_features': out_features,
**fc_kwargs,
}

return construct_from_registry(name=name,
registry=fcs,
pre_validation_function=torch.nn.Module,
kwargs=kwargs)
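
A usage sketch of the new builder, matching the attention and FFN call sites above. The `'torch'` key comes from `fc.py`; the feature sizes here are arbitrary.

```python
from llmfoundry.models.layers.layer_builders import build_fc

# Extra constructor arguments (bias, device, ...) travel in fc_kwargs and are
# merged with in_features/out_features before hitting the registry.
proj = build_fc(
    name='torch',
    in_features=512,
    out_features=2048,
    fc_kwargs={'bias': True},
)
```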
3 changes: 1 addition & 2 deletions llmfoundry/models/mpt/configuration_mpt.py
@@ -16,11 +16,10 @@
# HuggingFace can detect all the needed files to copy into its modules folder.
# Otherwise, certain modules are missing.
# isort: off
from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY # type: ignore (see note)
from llmfoundry.models.layers.norm import LPLayerNorm # type: ignore (see note)
from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY # type: ignore (see note)
from llmfoundry.models.layers.layer_builders import build_norm, build_fc # type: ignore (see note)
from llmfoundry.models.layers.dmoe import dMoE # type: ignore (see note)
from llmfoundry.models.layers.layer_builders import build_norm # type: ignore (see note)
from llmfoundry.layers_registry import norms # type: ignore (see note)
from llmfoundry.utils.registry_utils import construct_from_registry # type: ignore (see note)

5 changes: 2 additions & 3 deletions llmfoundry/models/utils/param_init_fns.py
@@ -12,9 +12,8 @@
from torch import nn
from torch.distributed._tensor import DTensor

from llmfoundry.layers_registry import norms
from llmfoundry.layers_registry import fcs, norms
from llmfoundry.models.layers.dmoe import GLU, MLP
from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY

try:
import transformer_engine.pytorch as te
@@ -182,7 +181,7 @@ def generic_param_init_fn_(
f'Expected init_div_is_residual to be boolean or numeric, got {init_div_is_residual}'
)

if isinstance(module, tuple(set(FC_CLASS_REGISTRY.values()))):
if isinstance(module, tuple(set([fcs.get(n) for n in fcs.get_all()]))):
# Linear
if hasattr(module, '_fused'):
fused_init_helper_(module, init_fn_)
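
For reference, the replaced `FC_CLASS_REGISTRY.values()` lookup becomes an enumeration of the registry. A standalone sketch of the same check; the `is_registered_fc` helper is illustrative and not part of the PR:

```python
import torch

from llmfoundry.layers_registry import fcs


def is_registered_fc(module: torch.nn.Module) -> bool:
    # Collect every registered FC class (deduplicated), mirroring the updated
    # isinstance check in generic_param_init_fn_.
    fc_classes = tuple(set(fcs.get(name) for name in fcs.get_all()))
    return isinstance(module, fc_classes)
```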
3 changes: 2 additions & 1 deletion llmfoundry/registry.py
@@ -12,7 +12,7 @@
from transformers import PreTrainedTokenizerBase

from llmfoundry.interfaces import CallbackWithConfig
from llmfoundry.layers_registry import norms
from llmfoundry.layers_registry import fcs, norms
from llmfoundry.utils.registry_utils import create_registry

_loggers_description = (
@@ -121,4 +121,5 @@
'metrics',
'dataloaders',
'norms',
'fcs',
]
1 change: 1 addition & 0 deletions tests/test_registry.py
@@ -31,6 +31,7 @@ def test_expected_registries_exist():
'metrics',
'models',
'norms',
'fcs',
}

assert existing_registries == expected_registry_names
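
Beyond the registry-name check, a hedged sketch of the kind of round-trip test the new registry enables; `TinyLinear` and the `'tiny'` key are made up for this illustration:

```python
from torch import nn

from llmfoundry.layers_registry import fcs
from llmfoundry.models.layers.layer_builders import build_fc


class TinyLinear(nn.Linear):
    """Hypothetical FC layer; accepts in_features and out_features."""


def test_custom_fc_round_trip():
    fcs.register('tiny', func=TinyLinear)
    layer = build_fc(
        name='tiny',
        in_features=8,
        out_features=4,
        fc_kwargs={'bias': False},
    )
    assert isinstance(layer, TinyLinear)
    assert layer.out_features == 4
```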