Commit: fc registry
dakinggg committed Apr 5, 2024
1 parent b81897a commit 6842520
Showing 8 changed files with 62 additions and 36 deletions.
8 changes: 8 additions & 0 deletions llmfoundry/layers_registry.py
@@ -15,6 +15,14 @@
                         entry_points=True,
                         description=_norm_description)
 
+_fc_description = """The fully connected layers registry is used to register classes that implement fully connected layers."""
+fcs = create_registry('llmfoundry',
+                      'fcs',
+                      generic_type=Type[torch.nn.Module],
+                      entry_points=True,
+                      description=_fc_description)
+
 __all__ = [
     'norms',
+    'fcs',
 ]
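
Not part of the commit, but for context: a minimal sketch of how a downstream package could register its own fully connected layer against the new `fcs` registry, assuming the registry API used elsewhere in this diff (`register(name, func=...)`, `get(name)`, and `get_all()` returning registered names). `MyLinear` is a hypothetical class used only for illustration.

import torch

from llmfoundry.layers_registry import fcs


class MyLinear(torch.nn.Linear):
    """Hypothetical fully connected layer; keeps the nn.Linear constructor signature."""


# Register under a custom name; call sites can then request it via fc_type='my_linear'.
fcs.register('my_linear', func=MyLinear)

assert fcs.get('my_linear') is MyLinear
assert 'my_linear' in fcs.get_all()
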
2 changes: 0 additions & 2 deletions llmfoundry/models/layers/__init__.py
@@ -7,7 +7,6 @@
     flash_attn_fn, scaled_multihead_dot_product_attention)
 from llmfoundry.models.layers.blocks import MPTBlock
 from llmfoundry.models.layers.custom_embedding import SharedEmbedding
-from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY
 from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY, MPTMLP, build_ffn
 from llmfoundry.models.layers.norm import LPLayerNorm
 
@@ -24,7 +23,6 @@
     'MPTMLP',
     'MPTBlock',
     'LPLayerNorm',
-    'FC_CLASS_REGISTRY',
     'SharedEmbedding',
     'FFN_CLASS_REGISTRY',
     'build_ffn',
21 changes: 11 additions & 10 deletions llmfoundry/models/layers/attention.py
@@ -14,8 +14,7 @@
 from packaging import version
 from torch import nn
 
-from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY
-from llmfoundry.models.layers.layer_builders import build_norm
+from llmfoundry.models.layers.layer_builders import build_fc, build_norm
 
 
 def is_flash_v2_installed(v2_version: str = '2.0.0'):
@@ -406,10 +405,11 @@ def __init__(
             'bias': bias,
         }
         fc_kwargs['device'] = device
-        self.Wqkv = FC_CLASS_REGISTRY[fc_type](
-            self.d_model,
-            self.d_model + 2 * self.kv_n_heads * self.head_dim,
-            **fc_kwargs,
+        self.Wqkv = build_fc(
+            name=fc_type,
+            in_features=self.d_model,
+            out_features=self.d_model + 2 * self.kv_n_heads * self.head_dim,
+            fc_kwargs=fc_kwargs,
         )
         # for param init fn; enables shape based init of fused layers
         fuse_splits = [
@@ -440,10 +440,11 @@ def __init__(
         else:
             raise ValueError(f'{attn_impl=} is an invalid setting.')
 
-        self.out_proj = FC_CLASS_REGISTRY[fc_type](
-            self.d_model,
-            self.d_model,
-            **fc_kwargs,
+        self.out_proj = build_fc(
+            name=fc_type,
+            in_features=self.d_model,
+            out_features=self.d_model,
+            fc_kwargs=fc_kwargs,
         )
         self.out_proj._is_residual = True
 
8 changes: 4 additions & 4 deletions llmfoundry/models/layers/fc.py
@@ -3,12 +3,12 @@
 
 from torch import nn
 
-FC_CLASS_REGISTRY = {
-    'torch': nn.Linear,
-}
+from llmfoundry.layers_registry import fcs
+
+fcs.register('torch', func=nn.Linear)
 
 try:
     import transformer_engine.pytorch as te
-    FC_CLASS_REGISTRY['te'] = te.Linear
+    fcs.register('te', func=te.Linear)
 except:
     pass
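
The guarded registration above mirrors the old behavior: the 'te' entry exists only when transformer_engine imports successfully. A small sketch (again assuming `get_all()` returns registered names) of how calling code could pick an fc type based on what actually got registered:

from llmfoundry.layers_registry import fcs

# Prefer the transformer_engine Linear when it was registered, otherwise fall
# back to the plain torch implementation.
fc_type = 'te' if 'te' in fcs.get_all() else 'torch'
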
29 changes: 16 additions & 13 deletions llmfoundry/models/layers/ffn.py
@@ -11,7 +11,7 @@
 import torch
 import torch.nn as nn
 
-from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY
+from llmfoundry.models.layers.layer_builders import build_fc
 
 try:
     import transformer_engine.pytorch as te
@@ -100,16 +100,18 @@ def __init__(
 
         self.fc_kwargs['device'] = device
 
-        self.up_proj = FC_CLASS_REGISTRY[fc_type](
-            d_model,
-            ffn_hidden_size,
-            **self.fc_kwargs,
+        self.up_proj = build_fc(
+            name=fc_type,
+            in_features=d_model,
+            out_features=ffn_hidden_size,
+            fc_kwargs=self.fc_kwargs,
         )
         self.act = act_fn
-        self.down_proj = FC_CLASS_REGISTRY[fc_type](
-            ffn_hidden_size,
-            d_model,
-            **self.fc_kwargs,
+        self.down_proj = build_fc(
+            name=fc_type,
+            in_features=ffn_hidden_size,
+            out_features=d_model,
+            fc_kwargs=self.fc_kwargs,
         )
         self.down_proj._is_residual = True
 
@@ -138,10 +140,11 @@ def __init__(
             device=device,
             bias=bias,
         )
-        self.gate_proj = FC_CLASS_REGISTRY[fc_type](
-            d_model,
-            self.up_proj.out_features,
-            **self.fc_kwargs,
+        self.gate_proj = build_fc(
+            name=fc_type,
+            in_features=d_model,
+            out_features=self.up_proj.out_features,
+            fc_kwargs=self.fc_kwargs,
         )
 
     @torch.compile
22 changes: 20 additions & 2 deletions llmfoundry/models/layers/layer_builders.py
@@ -1,11 +1,11 @@
 # Copyright 2024 MosaicML LLM Foundry authors
 # SPDX-License-Identifier: Apache-2.0
 
-from typing import List, Optional, Union
+from typing import Any, Dict, List, Optional, Union
 
 import torch
 
-from llmfoundry.layers_registry import norms
+from llmfoundry.layers_registry import fcs, norms
 from llmfoundry.utils.registry_utils import construct_from_registry
 
 
@@ -23,3 +23,21 @@ def build_norm(
                                    registry=norms,
                                    pre_validation_function=torch.nn.Module,
                                    kwargs=kwargs)
+
+
+def build_fc(
+    name: str,
+    in_features: int,
+    out_features: int,
+    fc_kwargs: Dict[str, Any],
+):
+    kwargs = {
+        'in_features': in_features,
+        'out_features': out_features,
+        **fc_kwargs,
+    }
+
+    return construct_from_registry(name=name,
+                                   registry=fcs,
+                                   pre_validation_function=torch.nn.Module,
+                                   kwargs=kwargs)
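
A usage sketch for the new builder (not from the commit): `build_fc` folds the shape arguments into `fc_kwargs` and hands the merged kwargs to `construct_from_registry`, so with `name='torch'` the call below should be equivalent to `torch.nn.Linear(256, 1024, bias=False, device='cpu')`.

from llmfoundry.models.layers.layer_builders import build_fc

# Builds whichever class is registered under 'torch' (torch.nn.Linear by default),
# passing in_features/out_features plus any extra constructor kwargs.
proj = build_fc(
    name='torch',
    in_features=256,
    out_features=1024,
    fc_kwargs={'bias': False, 'device': 'cpu'},
)
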
3 changes: 1 addition & 2 deletions llmfoundry/models/mpt/configuration_mpt.py
@@ -16,10 +16,9 @@
 # HuggingFace can detect all the needed files to copy into its modules folder.
 # Otherwise, certain modules are missing.
 # isort: off
-from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY  # type: ignore (see note)
 from llmfoundry.models.layers.norm import LPLayerNorm  # type: ignore (see note)
 from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY  # type: ignore (see note)
-from llmfoundry.models.layers.layer_builders import build_norm  # type: ignore (see note)
+from llmfoundry.models.layers.layer_builders import build_norm, build_fc  # type: ignore (see note)
 from llmfoundry.layers_registry import norms  # type: ignore (see note)
 from llmfoundry.utils.registry_utils import construct_from_registry  # type: ignore (see note)
 
5 changes: 2 additions & 3 deletions llmfoundry/models/utils/param_init_fns.py
@@ -10,8 +10,7 @@
 import torch
 from torch import nn
 
-from llmfoundry.layers_registry import norms
-from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY
+from llmfoundry.layers_registry import fcs, norms
 
 try:
     import transformer_engine.pytorch as te
@@ -87,7 +86,7 @@ def generic_param_init_fn_(
            f'Expected init_div_is_residual to be boolean or numeric, got {init_div_is_residual}'
        )
 
-    if isinstance(module, tuple(set(FC_CLASS_REGISTRY.values()))):
+    if isinstance(module, tuple(set([fcs.get(n) for n in fcs.get_all()]))):
         # Linear
         if hasattr(module, '_fused'):
             fused_init_helper_(module, init_fn_)
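
Because the init path now walks the registry instead of a hard-coded dict, any class registered with `fcs` (including a custom one, as in the sketch further up) is picked up by the Linear branch of `generic_param_init_fn_`. A hedged illustration of the same membership check in isolation:

from llmfoundry.layers_registry import fcs

# Collect every registered fully connected class, mirroring the isinstance
# check in generic_param_init_fn_ above.
registered_fc_classes = tuple(set(fcs.get(n) for n in fcs.get_all()))

module = fcs.get('torch')(in_features=8, out_features=8)
assert isinstance(module, registered_fc_classes)
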
