Param init registry (#1096)
dakinggg authored Apr 13, 2024
1 parent cb0de4f commit 676ad7f
Showing 7 changed files with 196 additions and 64 deletions.
22 changes: 22 additions & 0 deletions llmfoundry/layers_registry.py
@@ -73,8 +73,30 @@
     entry_points=True,
     description=_attention_implementations_description)

+_param_init_fns_description = (
+    'The param_init_fns registry is used to register functions that initialize parameters.'
+    +
+    'These will be called on a module to initialize its parameters. See param_init_fns.py for examples.'
+)
+param_init_fns = create_registry('llmfoundry',
+                                 'param_init_fns',
+                                 generic_type=Callable[..., None],
+                                 entry_points=True,
+                                 description=_param_init_fns_description)
+
+_module_init_fns_description = """The module_init_fns registry is used to register functions that initialize specific modules.
+These functions should return True if they initialize the module, and False otherwise. This allows them to be called without knowing their contents.
+They should take in the module, init_div_is_residual, and div_is_residual arguments."""
+module_init_fns = create_registry('llmfoundry',
+                                  'module_init_fns',
+                                  generic_type=Callable[..., bool],
+                                  entry_points=True,
+                                  description=_module_init_fns_description)
+
 __all__ = [
     'norms',
+    'param_init_fns',
+    'module_init_fns',
     'ffns',
     'ffns_with_norm',
     'ffns_with_megablocks',
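
For orientation, below is a minimal sketch of how a downstream project could hook into the two new registries. The function names and bodies are made up for illustration, and the `register(name, func=...)` call assumes the same catalogue-style API used by llmfoundry's other registries (e.g. `norms`); only the signature conventions come from the descriptions above: `param_init_fns` entries return `None`, while `module_init_fns` entries take the module, `init_div_is_residual`, and `div_is_residual`, and return `True` only when they handled the module.

```python
# Hypothetical example, not part of this commit: registering custom initializers.
import torch.nn as nn

from llmfoundry.layers_registry import module_init_fns, param_init_fns


def zero_bias_param_init_(module: nn.Module, **kwargs) -> None:
    """A param_init_fns entry: walks the module tree and initializes parameters in place."""
    for child in module.modules():
        if isinstance(child, nn.Linear) and child.bias is not None:
            nn.init.zeros_(child.bias)


def linear_module_init(
    module: nn.Module,
    init_div_is_residual,
    div_is_residual,
    **kwargs,
) -> bool:
    """A module_init_fns entry: returns True only if it initialized this module."""
    if isinstance(module, nn.Linear):
        nn.init.xavier_uniform_(module.weight)
        return True
    return False


# Assumed registration API (mirrors other llmfoundry registries).
param_init_fns.register('zero_bias_param_init_', func=zero_bias_param_init_)
module_init_fns.register('linear_module_init', func=linear_module_init)
```

Because every `module_init_fns` entry reports whether it handled the module, a generic initializer can try each registered function in turn without hard-coding the set of supported layer types, which is what the registry description means by calling them "without knowing their contents".
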
7 changes: 3 additions & 4 deletions llmfoundry/models/mpt/modeling_mpt.py
@@ -43,7 +43,7 @@
 from transformers.models.llama.modeling_llama import \
     LlamaRotaryEmbedding as HFRotaryEmbedding

-from llmfoundry.layers_registry import norms
+from llmfoundry.layers_registry import norms, param_init_fns
 from llmfoundry.models.layers.attention import (attn_bias_shape,
                                                  build_attn_bias, gen_slopes)
 from llmfoundry.models.layers.blocks import MPTBlock
@@ -62,7 +62,6 @@
     init_empty_weights  # type: ignore (see note)
 from llmfoundry.models.utils.param_init_fns import (
     generic_param_init_fn_,  # type: ignore (see note)
-    MODEL_INIT_REGISTRY,
 )
 from llmfoundry.models.layers.ffn import resolve_ffn_act_fn  # type: ignore (see note)

@@ -678,7 +677,7 @@ def forward(
     # Param Initialization, needed for device='meta' fast initialization
     def param_init_fn(self, module: nn.Module) -> None:
         init_fn_name = self.config.init_config['name']
-        MODEL_INIT_REGISTRY[init_fn_name](
+        param_init_fns.get(init_fn_name)(
             module=module,
             n_layers=self.config.n_layers,
             d_model=self.config.d_model,
@@ -838,7 +837,7 @@ def forward(
     # Param Initialization, needed for device='meta' fast initialization
     def param_init_fn(self, module: nn.Module) -> None:
         init_fn_name = self.config.init_config['name']
-        MODEL_INIT_REGISTRY[init_fn_name](
+        param_init_fns.get(init_fn_name)(
             module=module,
             n_layers=self.config.n_layers,
             d_model=self.config.d_model,
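
The change in both `param_init_fn` methods is the same: the initializer is now resolved through the `param_init_fns` registry instead of indexing the removed `MODEL_INIT_REGISTRY` dict. Below is a rough, self-contained sketch of that lookup-and-call pattern; the config and the `toy_init_` entry are made up for illustration, and only `module`, `n_layers`, and `d_model` appear as keyword arguments in the hunks above.

```python
# Illustrative sketch of the registry-based dispatch; 'toy_init_' and the
# config dict are hypothetical, not something shipped by this commit.
import torch.nn as nn

from llmfoundry.layers_registry import param_init_fns


def toy_init_(module: nn.Module, n_layers: int, d_model: int, **kwargs) -> None:
    """Toy initializer: renormalize Linear weights based on n_layers and d_model."""
    for child in module.modules():
        if isinstance(child, nn.Linear):
            nn.init.normal_(child.weight, std=d_model**-0.5 / n_layers**0.5)


param_init_fns.register('toy_init_', func=toy_init_)

# What MPTModel.param_init_fn now does, roughly: read the name from
# init_config, resolve it through the registry, then call the result.
init_config = {'name': 'toy_init_'}
init_fn = param_init_fns.get(init_config['name'])
init_fn(module=nn.Linear(8, 8), n_layers=2, d_model=8)
```

Compared with the old dict lookup, going through `param_init_fns.get` lets new initialization schemes be plugged in from outside the package, including via entry points, since the registry is created with `entry_points=True`.
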
4 changes: 1 addition & 3 deletions llmfoundry/models/utils/__init__.py
@@ -6,14 +6,12 @@
                                                         init_on_device)
 from llmfoundry.models.utils.mpt_param_count import (mpt_get_active_params,
                                                       mpt_get_total_params)
-from llmfoundry.models.utils.param_init_fns import (MODEL_INIT_REGISTRY,
-                                                     generic_param_init_fn_)
+from llmfoundry.models.utils.param_init_fns import generic_param_init_fn_

 __all__ = [
     'init_empty_weights',
     'init_on_device',
     'generic_param_init_fn_',
-    'MODEL_INIT_REGISTRY',
     'config_moe_args',
     'mpt_get_active_params',
     'mpt_get_total_params',
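
With `MODEL_INIT_REGISTRY` gone from `llmfoundry.models.utils`, external code that indexed it directly needs a small migration. A hedged before/after sketch follows; it assumes the old registry keys, such as 'default_', were re-registered under the same names, which this page does not show.

```python
# Before this commit (import no longer exists):
# from llmfoundry.models.utils import MODEL_INIT_REGISTRY
# init_fn = MODEL_INIT_REGISTRY['default_']

# After this commit: look the function up through the registry instead.
from llmfoundry.layers_registry import param_init_fns

init_fn = param_init_fns.get('default_')  # 'default_' assumed to keep its old name
```
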
