Param init registry #1096

Merged
merged 22 commits on Apr 13, 2024
22 changes: 22 additions & 0 deletions llmfoundry/layers_registry.py
@@ -73,8 +73,30 @@
entry_points=True,
description=_attention_implementations_description)

_param_init_fns_description = (
'The param_init_fns registry is used to register functions that initialize parameters. '
+
'These will be called on a module to initialize its parameters. See param_init_fns.py for examples.'
)
param_init_fns = create_registry('llmfoundry',
'param_init_fns',
generic_type=Callable[..., None],
entry_points=True,
description=_param_init_fns_description)
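
The description above implies a register-then-look-up flow. Below is a minimal usage sketch, assuming the catalogue-style register()/get() API that create_registry exposes; the function name and the zero-init logic are hypothetical, purely for illustration:

import torch.nn as nn

from llmfoundry.layers_registry import param_init_fns


# Hypothetical init fn: zero every parameter owned directly by the module.
@param_init_fns.register('zeros_param_init_fn_')
def zeros_param_init_fn_(module: nn.Module, **kwargs) -> None:
    for param in module.parameters(recurse=False):
        nn.init.zeros_(param)


# Lookup mirrors the call sites in modeling_mpt.py below: fetch the fn by name,
# then call it on a module.
init_fn = param_init_fns.get('zeros_param_init_fn_')
init_fn(module=nn.Linear(8, 8))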

_module_init_fns_description = """The module_init_fns registry is used to register functions that initialize specific modules.
These functions should return True if they initialize the module, and False otherwise. This lets a caller try each registered function on a module without knowing which module types it handles.
They should take in the module, init_div_is_residual, and div_is_residual arguments."""
module_init_fns = create_registry('llmfoundry',
'module_init_fns',
generic_type=Callable[..., bool],
entry_points=True,
description=_module_init_fns_description)
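
A minimal sketch of a conforming module init fn under that contract: it accepts the module plus the init_div_is_residual and div_is_residual arguments, initializes only the module types it recognizes, and reports via its return value whether it did anything (the registration key and init choice here are hypothetical):

from typing import Any, Optional, Union

import torch.nn as nn

from llmfoundry.layers_registry import module_init_fns


@module_init_fns.register('embedding_only')
def embedding_only_init(
    module: nn.Module,
    init_div_is_residual: Union[int, float, str, bool] = False,
    div_is_residual: Optional[float] = None,
    **kwargs: Any,
) -> bool:
    # Claim only nn.Embedding; anything else is left to other registered
    # module init fns, signalled by returning False.
    if isinstance(module, nn.Embedding):
        nn.init.normal_(module.weight, mean=0.0, std=0.02)
        return True
    return False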

__all__ = [
'norms',
'param_init_fns',
'module_init_fns',
'ffns',
'ffns_with_norm',
'ffns_with_megablocks',
7 changes: 3 additions & 4 deletions llmfoundry/models/mpt/modeling_mpt.py
@@ -43,7 +43,7 @@
from transformers.models.llama.modeling_llama import \
LlamaRotaryEmbedding as HFRotaryEmbedding

from llmfoundry.layers_registry import norms
from llmfoundry.layers_registry import norms, param_init_fns
from llmfoundry.models.layers.attention import (attn_bias_shape,
build_attn_bias, gen_slopes)
from llmfoundry.models.layers.blocks import MPTBlock
Expand All @@ -62,7 +62,6 @@
init_empty_weights # type: ignore (see note)
from llmfoundry.models.utils.param_init_fns import (
generic_param_init_fn_, # type: ignore (see note)
MODEL_INIT_REGISTRY,
)
from llmfoundry.models.layers.ffn import resolve_ffn_act_fn # type: ignore (see note)

Expand Down Expand Up @@ -678,7 +677,7 @@ def forward(
# Param Initialization, needed for device='meta' fast initialization
def param_init_fn(self, module: nn.Module) -> None:
init_fn_name = self.config.init_config['name']
MODEL_INIT_REGISTRY[init_fn_name](
param_init_fns.get(init_fn_name)(
module=module,
n_layers=self.config.n_layers,
d_model=self.config.d_model,
Expand Down Expand Up @@ -838,7 +837,7 @@ def forward(
# Param Initialization, needed for device='meta' fast initialization
def param_init_fn(self, module: nn.Module) -> None:
init_fn_name = self.config.init_config['name']
MODEL_INIT_REGISTRY[init_fn_name](
param_init_fns.get(init_fn_name)(
module=module,
n_layers=self.config.n_layers,
d_model=self.config.d_model,
4 changes: 1 addition & 3 deletions llmfoundry/models/utils/__init__.py
@@ -6,14 +6,12 @@
init_on_device)
from llmfoundry.models.utils.mpt_param_count import (mpt_get_active_params,
mpt_get_total_params)
from llmfoundry.models.utils.param_init_fns import (MODEL_INIT_REGISTRY,
generic_param_init_fn_)
from llmfoundry.models.utils.param_init_fns import generic_param_init_fn_

__all__ = [
'init_empty_weights',
'init_on_device',
'generic_param_init_fn_',
'MODEL_INIT_REGISTRY',
'config_moe_args',
'mpt_get_active_params',
'mpt_get_total_params',