diff --git a/llmfoundry/layers_registry.py b/llmfoundry/layers_registry.py
index 19b3b1c5cf..24593144aa 100644
--- a/llmfoundry/layers_registry.py
+++ b/llmfoundry/layers_registry.py
@@ -73,8 +73,30 @@
     entry_points=True,
     description=_attention_implementations_description)
 
+_param_init_fns_description = (
+    'The param_init_fns registry is used to register functions that initialize parameters.'
+    +
+    'These will be called on a module to initialize its parameters. See param_init_fns.py for examples.'
+)
+param_init_fns = create_registry('llmfoundry',
+                                 'param_init_fns',
+                                 generic_type=Callable[..., None],
+                                 entry_points=True,
+                                 description=_param_init_fns_description)
+
+_module_init_fns_description = """The module_init_fns registry is used to register functions that initialize specific modules.
+These functions should return True if they initialize the module, and False otherwise. This allows them to be called without knowing their contents.
+They should take in the module, init_div_is_residual, and div_is_residual arguments."""
+module_init_fns = create_registry('llmfoundry',
+                                  'module_init_fns',
+                                  generic_type=Callable[..., bool],
+                                  entry_points=True,
+                                  description=_module_init_fns_description)
+
 __all__ = [
     'norms',
+    'param_init_fns',
+    'module_init_fns',
     'ffns',
     'ffns_with_norm',
     'ffns_with_megablocks',
diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py
index 124ab3db3e..1ef62a3b19 100644
--- a/llmfoundry/models/mpt/modeling_mpt.py
+++ b/llmfoundry/models/mpt/modeling_mpt.py
@@ -43,7 +43,7 @@
 from transformers.models.llama.modeling_llama import \
     LlamaRotaryEmbedding as HFRotaryEmbedding
 
-from llmfoundry.layers_registry import norms
+from llmfoundry.layers_registry import norms, param_init_fns
 from llmfoundry.models.layers.attention import (attn_bias_shape,
                                                 build_attn_bias, gen_slopes)
 from llmfoundry.models.layers.blocks import MPTBlock
@@ -62,7 +62,6 @@
     init_empty_weights  # type: ignore (see note)
 from llmfoundry.models.utils.param_init_fns import (
     generic_param_init_fn_,  # type: ignore (see note)
-    MODEL_INIT_REGISTRY,
 )
 from llmfoundry.models.layers.ffn import resolve_ffn_act_fn  # type: ignore (see note)
 
@@ -678,7 +677,7 @@ def forward(
     # Param Initialization, needed for device='meta' fast initialization
     def param_init_fn(self, module: nn.Module) -> None:
         init_fn_name = self.config.init_config['name']
-        MODEL_INIT_REGISTRY[init_fn_name](
+        param_init_fns.get(init_fn_name)(
             module=module,
             n_layers=self.config.n_layers,
             d_model=self.config.d_model,
@@ -838,7 +837,7 @@ def forward(
     # Param Initialization, needed for device='meta' fast initialization
     def param_init_fn(self, module: nn.Module) -> None:
         init_fn_name = self.config.init_config['name']
-        MODEL_INIT_REGISTRY[init_fn_name](
+        param_init_fns.get(init_fn_name)(
             module=module,
             n_layers=self.config.n_layers,
             d_model=self.config.d_model,
diff --git a/llmfoundry/models/utils/__init__.py b/llmfoundry/models/utils/__init__.py
index 41313b8729..ca5fa4b935 100644
--- a/llmfoundry/models/utils/__init__.py
+++ b/llmfoundry/models/utils/__init__.py
@@ -6,14 +6,12 @@
                                                   init_on_device)
 from llmfoundry.models.utils.mpt_param_count import (mpt_get_active_params,
                                                      mpt_get_total_params)
-from llmfoundry.models.utils.param_init_fns import (MODEL_INIT_REGISTRY,
-                                                    generic_param_init_fn_)
+from llmfoundry.models.utils.param_init_fns import generic_param_init_fn_
 
 __all__ = [
     'init_empty_weights',
     'init_on_device',
     'generic_param_init_fn_',
-    'MODEL_INIT_REGISTRY',
     'config_moe_args',
     'mpt_get_active_params',
     'mpt_get_total_params',
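Note on usage: with `MODEL_INIT_REGISTRY` gone, a new initializer is added through the `param_init_fns` registry rather than by editing a dict. A minimal, hypothetical sketch (the name `trunc_normal_` and its keyword arguments are illustrative, not part of this diff; it assumes the registry API shown above and the `generic_param_init_fn_` signature that appears later in the patch):

```python
from functools import partial
from typing import Any

import torch
import torch.nn as nn

from llmfoundry.layers_registry import param_init_fns
from llmfoundry.models.utils import generic_param_init_fn_


def trunc_normal_param_init_fn_(
    module: nn.Module,
    n_layers: int,
    init_std: float = 0.02,
    **kwargs: Any,
) -> None:
    # Build a base init fn and delegate per-module dispatch to
    # generic_param_init_fn_, mirroring the built-in initializers.
    init_fn_ = partial(torch.nn.init.trunc_normal_, mean=0.0, std=init_std)
    generic_param_init_fn_(module=module,
                           init_fn_=init_fn_,
                           n_layers=n_layers,
                           **kwargs)


param_init_fns.register('trunc_normal_', func=trunc_normal_param_init_fn_)

# MPT's param_init_fn now resolves built-in and custom names the same way:
param_init_fns.get('trunc_normal_')(module=nn.Linear(8, 8), n_layers=2, d_model=8)
```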
diff --git a/llmfoundry/models/utils/param_init_fns.py b/llmfoundry/models/utils/param_init_fns.py
index bd409dee36..e64cde5e96 100644
--- a/llmfoundry/models/utils/param_init_fns.py
+++ b/llmfoundry/models/utils/param_init_fns.py
@@ -12,7 +12,8 @@
 from torch import nn
 from torch.distributed._tensor import DTensor
 
-from llmfoundry.layers_registry import fcs, norms
+from llmfoundry.layers_registry import (fcs, module_init_fns, norms,
+                                        param_init_fns)
 from llmfoundry.models.layers.dmoe import GLU, MLP
 
 try:
@@ -53,7 +54,7 @@ def fused_init_helper_(
     Args:
         module (nn.Module): The module to initialize.
         init_fn_ (Callable): Initialization method.
-        name_param (str): Name of parameter to initalize within the module.
+        name_param (str): Name of parameter to initialize within the module.
     """
     _fused = getattr(module, '_fused', None)
     if _fused is None:
@@ -89,7 +90,7 @@ def stacked_init_helper_(
     init_fn_: Callable,
     name_param: str = 'weight',
 ):
-    """Initializes parameters stacked along a new dimention.
+    """Initializes parameters stacked along a new dimension.
 
     Parameter initialization is often based on the parameters shape. If a layer is stacked,
     initialization should be based on the shapes of the original tensor instead of the
@@ -99,7 +100,7 @@ def stacked_init_helper_(
     Args:
         module (nn.Module): The module to initialize.
         init_fn_ (Callable): Initialization method.
-        name_param (str): Name of parameter to initalize within the module.
+        name_param (str): Name of parameter to initialize within the module.
     """
     stack_dim = getattr(module, '_stack_dim', None)
    if stack_dim is None:
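The hunks that follow carve `generic_param_init_fn_` into per-module handlers registered under `module_init_fns`. Per the registry description above, each handler accepts the full keyword set, ignores what it does not need, and returns `True` only if it actually initialized the module. A minimal hypothetical handler following that contract (`bilinear` is illustrative, not part of this diff):

```python
from typing import Any, Callable

import torch
import torch.nn as nn

from llmfoundry.layers_registry import module_init_fns


def bilinear_init(
    module: nn.Module,
    init_fn_: Callable,
    **kwargs: Any,
) -> bool:
    del kwargs  # unused, just to capture any extra args
    if isinstance(module, nn.Bilinear):
        # Claim the module: initialize its weight and zero its bias.
        init_fn_(module.weight)
        if module.bias is not None:
            torch.nn.init.zeros_(module.bias)
        return True
    return False


module_init_fns.register('bilinear', func=bilinear_init)
```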
@@ -113,12 +114,12 @@ def stacked_param_init_helper(
     init_fn_: Callable,
     stack_dim: int,
 ):
-    """Initialize parameters stacked along a new dimention.
+    """Initialize parameters stacked along a new dimension.
 
     Args:
         param (torch.Tensor): Tensor to initialize.
         init_fn_ (Callable): Initialization method.
-        stack_dim (int): Dimention along with parameters are stacked
+        stack_dim (int): Dimension along with parameters are stacked
     """
     p_ndims = param.ndim
 
@@ -147,39 +148,14 @@ def _flip_fan_mode(init_fn_: Callable):
     return _init_fn_
 
 
-def generic_param_init_fn_(
+def fc_init(
     module: nn.Module,
     init_fn_: Callable,
-    n_layers: int,
-    d_model: Optional[int] = None,
-    init_div_is_residual: Union[int, float, str, bool] = True,
-    emb_init_std: Optional[float] = None,
-    emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
+    init_div_is_residual: Union[int, float, str, bool],
+    div_is_residual: Optional[float],
     **kwargs: Any,
-) -> None:
-    del kwargs  # unused, just to capture any extra args from the config
-    # enable user to divide _is_residual weights by
-
-    # a value which defaults to math.sqrt(2 * cfg.n_layers)
-    init_div_is_residual = init_div_is_residual
-
-    if init_div_is_residual is False:
-        # not used, for pyright
-        div_is_residual = 1.0
-    elif init_div_is_residual is True:
-        div_is_residual = math.sqrt(2 * n_layers)
-    elif isinstance(init_div_is_residual, float) or isinstance(
-            init_div_is_residual, int):
-        div_is_residual = init_div_is_residual
-    elif init_div_is_residual.isnumeric():
-        # do not trust YAML parsing to always convert numbers to numbers
-        div_is_residual = float(init_div_is_residual)
-    else:
-        # not used, for pyright
-        div_is_residual = 1.0
-        raise ValueError(
-            f'Expected init_div_is_residual to be boolean or numeric, got {init_div_is_residual}'
-        )
+) -> bool:
+    del kwargs  # unused, just to capture any extra args
 
     if isinstance(module, tuple(set([fcs.get(n) for n in fcs.get_all()]))):
         # Linear
@@ -195,8 +171,21 @@
                 module, '_is_residual', False):
             with torch.no_grad():
                 module.weight.div_(div_is_residual)  # type: ignore
+        return True
 
-    elif isinstance(module, nn.Embedding):
+    return False
+
+
+def embedding_init(
+    module: nn.Module,
+    init_fn_: Callable,
+    emb_init_std: Optional[float],
+    emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]],
+    **kwargs: Any,
+) -> bool:
+    del kwargs  # unused, just to capture any extra args
+
+    if isinstance(module, nn.Embedding):
         # Embedding
         if emb_init_std is not None:
             std = emb_init_std
@@ -223,8 +212,19 @@
 
         emb_init_fn_(module.weight)
 
-    elif isinstance(module,
-                    tuple(set([norms.get(name) for name in norms.get_all()]))):
+        return True
+
+    return False
+
+
+def norm_init(
+    module: nn.Module,
+    **kwargs: Any,
+) -> bool:
+    del kwargs  # unused, just to capture any extra args
+
+    if isinstance(module,
+                  tuple(set([norms.get(name) for name in norms.get_all()]))):
         # Norm
         if hasattr(module, 'weight') and isinstance(module.weight,
                                                     torch.Tensor):
@@ -232,7 +232,22 @@
         if hasattr(module, 'bias') and isinstance(module.bias, torch.Tensor):
             torch.nn.init.zeros_(module.bias)
 
-    elif isinstance(module, nn.MultiheadAttention):
+        return True
+
+    return False
+
+
+def multihead_attention_init(
+    module: nn.Module,
+    init_fn_: Callable,
+    d_model: Optional[int],
+    init_div_is_residual: Union[int, float, str, bool],
+    div_is_residual: float,
+    **kwargs: Any,
+) -> bool:
+    del kwargs  # unused, just to capture any extra args
+
+    if isinstance(module, nn.MultiheadAttention):
         # torch's MultiheadAttention
         if module._qkv_same_embed_dim:
             assert module.in_proj_weight is not None
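The handlers above can also be exercised directly, since each one reports whether it claimed the module. A small sketch (it assumes `nn.Linear` is registered in the `fcs` registry, as in stock LLM Foundry, so `fc_init` recognizes it):

```python
import torch
import torch.nn as nn

from llmfoundry.models.utils.param_init_fns import fc_init, norm_init

linear = nn.Linear(16, 16)

handled = fc_init(
    module=linear,
    init_fn_=torch.nn.init.xavier_uniform_,
    init_div_is_residual=False,
    div_is_residual=1.0,
)
print(handled)  # True: the Linear was recognized and initialized

print(norm_init(module=linear))  # False: not a registered norm class, left untouched
```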
@@ -267,7 +282,19 @@
         if module.out_proj.bias is not None:
             torch.nn.init.zeros_(module.out_proj.bias)
 
-    elif te is not None and isinstance(module, te.LayerNormMLP):
+        return True
+
+    return False
+
+
+def te_layernorm_mlp_init(
+    module: nn.Module,
+    init_fn_: Callable,
+    **kwargs: Any,
+) -> bool:
+    del kwargs  # unused, just to capture any extra args
+
+    if te is not None and isinstance(module, te.LayerNormMLP):
         if isinstance(module.layer_norm_weight, torch.Tensor):
             torch.nn.init.ones_(module.layer_norm_weight)
         if isinstance(module.layer_norm_bias, torch.Tensor):
@@ -285,7 +312,19 @@
             with torch.no_grad():
                 module.fc2_weight.div_(div_is_residual)  # type: ignore
 
-    elif megablocks is not None and isinstance(module, (
+        return True
+
+    return False
+
+
+def moe_init(
+    module: nn.Module,
+    init_fn_: Callable,
+    init_div_is_residual: Union[int, float, str, bool],
+    div_is_residual: float,
+    **kwargs: Any,
+) -> bool:
+    if megablocks is not None and isinstance(module, (
             megablocks.layers.moe.MoE,
             megablocks.layers.dmoe.dMoE,
             megablocks.layers.moe.ParallelMLP,
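The next hunk rebuilds `generic_param_init_fn_` around the registry but keeps the existing `init_div_is_residual` behaviour: modules flagged with `_is_residual` have their weights divided by `math.sqrt(2 * n_layers)` when the option is left at its default of `True`. A quick worked example of that divisor:

```python
import math

n_layers = 24
div_is_residual = math.sqrt(2 * n_layers)  # default when init_div_is_residual is True
print(round(div_is_residual, 3))  # 6.928: residual projections are scaled down by ~6.93x
```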
@@ -294,32 +333,96 @@
         if hasattr(module, 'bias') and module.bias is not None:
             # Initialize bias to 0
             torch.nn.init.zeros_(module.bias)  # type: ignore
+        return True
     elif megablocks is not None and isinstance(module,
                                                megablocks.layers.glu.SparseGLU):
         _megablocks_sparse_glu_generic_param_init_fn_(
             module, init_fn_, bool(init_div_is_residual), div_is_residual)
+        return True
     elif megablocks is not None and isinstance(module,
                                                megablocks.layers.mlp.SparseMLP):
         _megablocks_sparse_mlp_generic_param_init_fn_(
             module, init_fn_, bool(init_div_is_residual), div_is_residual)
+        return True
     elif megablocks is not None and isinstance(module,
                                                megablocks.layers.mlp.MLP):
         _megablocks_mlp_generic_param_init_fn_(module, init_fn_,
                                                bool(init_div_is_residual),
                                                div_is_residual)
+        return True
     elif isinstance(module, GLU):
         init_fn_(module.w1)
         init_fn_(module.v1)
         init_fn_(module.w2)
+        return True
     elif isinstance(module, MLP):
         init_fn_(module.w1)
         init_fn_(module.w2)
+        return True
+
+    return False
+
+
+def generic_param_init_fn_(
+    module: nn.Module,
+    init_fn_: Callable,
+    n_layers: int,
+    d_model: Optional[int] = None,
+    init_div_is_residual: Union[int, float, str, bool] = True,
+    emb_init_std: Optional[float] = None,
+    emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
+    **kwargs: Any,
+) -> None:
+    del kwargs  # unused, just to capture any extra args from the config
+    # enable user to divide _is_residual weights by
+
+    # a value which defaults to math.sqrt(2 * cfg.n_layers)
+    init_div_is_residual = init_div_is_residual
+
+    if init_div_is_residual is False:
+        # not used, for pyright
+        div_is_residual = 1.0
+    elif init_div_is_residual is True:
+        div_is_residual = math.sqrt(2 * n_layers)
+    elif isinstance(init_div_is_residual, float) or isinstance(
+            init_div_is_residual, int):
+        div_is_residual = init_div_is_residual
+    elif init_div_is_residual.isnumeric():
+        # do not trust YAML parsing to always convert numbers to numbers
+        div_is_residual = float(init_div_is_residual)
     else:
+        # not used, for pyright
+        div_is_residual = 1.0
+        raise ValueError(
+            f'Expected init_div_is_residual to be boolean or numeric, got {init_div_is_residual}'
+        )
+
+    all_module_init_fns = [
+        module_init_fns.get(name) for name in module_init_fns.get_all()
+    ]
+    did_init = False
+    for module_init_fn in all_module_init_fns:
+        did_init = module_init_fn(
+            module=module,
+            init_fn_=init_fn_,
+            d_model=d_model,
+            init_div_is_residual=init_div_is_residual,
+            div_is_residual=div_is_residual,
+            emb_init_std=emb_init_std,
+            emb_init_uniform_lim=emb_init_uniform_lim,
+        )
+
+        if did_init:
+            break
+
+    if not did_init:
         for _ in module.parameters(recurse=False):
             # raise error if uninitialized module has any parameters
             raise NotImplementedError(
-                f'{module.__class__.__name__} parameters are not initialized by param_init_fn.'
-            )
+                f'{module.__class__.__name__} parameters are not initialized by any of the registered module_init_fns. '
+                +
+                'Please add an appropriate module_init_fn to the registry. Currently registered module_init_fns are: '
+                + ', '.join(module_init_fns.get_all()))
 
 
 def _megablocks_sparse_mlp_generic_param_init_fn_(
@@ -725,13 +828,18 @@ def xavier_normal_param_init_fn_(
     )
 
 
-MODEL_INIT_REGISTRY = {
-    'default_': torch_default_param_init_fn_,
-    'baseline_': baseline_param_init_fn_,
-    'kaiming_uniform_': kaiming_uniform_param_init_fn_,
-    'kaiming_normal_': kaiming_normal_param_init_fn_,
-    'neox_init_': neox_param_init_fn_,
-    'small_init_': small_param_init_fn_,
-    'xavier_uniform_': xavier_uniform_param_init_fn_,
-    'xavier_normal_': xavier_normal_param_init_fn_,
-}
+param_init_fns.register('default_', func=torch_default_param_init_fn_)
+param_init_fns.register('baseline_', func=baseline_param_init_fn_)
+param_init_fns.register('kaiming_uniform_', func=kaiming_uniform_param_init_fn_)
+param_init_fns.register('kaiming_normal_', func=kaiming_normal_param_init_fn_)
+param_init_fns.register('neox_init_', func=neox_param_init_fn_)
+param_init_fns.register('small_init_', func=small_param_init_fn_)
+param_init_fns.register('xavier_uniform_', func=xavier_uniform_param_init_fn_)
+param_init_fns.register('xavier_normal_', func=xavier_normal_param_init_fn_)
+
+module_init_fns.register('fc', func=fc_init)
+module_init_fns.register('embedding', func=embedding_init)
+module_init_fns.register('norm', func=norm_init)
+module_init_fns.register('multihead_attention', func=multihead_attention_init)
+module_init_fns.register('te_layernorm_mlp', func=te_layernorm_mlp_init)
+module_init_fns.register('moe', func=moe_init)
diff --git a/llmfoundry/registry.py b/llmfoundry/registry.py
index 64ed5d7b65..6e1824ea08 100644
--- a/llmfoundry/registry.py
+++ b/llmfoundry/registry.py
@@ -15,7 +15,7 @@
 from llmfoundry.layers_registry import (attention_classes,
                                         attention_implementations, fcs, ffns,
                                         ffns_with_megablocks, ffns_with_norm,
-                                        norms)
+                                        module_init_fns, norms, param_init_fns)
 from llmfoundry.utils.registry_utils import create_registry
 
 _loggers_description = (
@@ -133,6 +133,8 @@
     'metrics',
     'dataloaders',
     'norms',
+    'param_init_fns',
+    'module_init_fns',
     'ffns',
     'ffns_with_norm',
     'ffns_with_megablocks',
diff --git a/tests/models/utils/test_param_init_fns.py b/tests/models/utils/test_param_init_fns.py
index 6be2c5ca42..0efc245602 100644
--- a/tests/models/utils/test_param_init_fns.py
+++ b/tests/models/utils/test_param_init_fns.py
@@ -12,7 +12,8 @@
 from omegaconf import OmegaConf as om
 from torch import nn
 
-from llmfoundry.models.utils import MODEL_INIT_REGISTRY, generic_param_init_fn_
+from llmfoundry.layers_registry import param_init_fns
+from llmfoundry.models.utils import generic_param_init_fn_
 
 
 class MLP(nn.Module):
@@ -150,7 +151,7 @@ def test_emb_init(emb_init_cfg: Optional[Tuple[str, Union[int, List[int]]]]):
                          bias=True)),
     ]))
 
-    model.apply(partial(MODEL_INIT_REGISTRY['kaiming_normal_'], **dict_cfg))
+    model.apply(partial(param_init_fns.get('kaiming_normal_'), **dict_cfg))
 
     assert isinstance(model.emb, torch.nn.Embedding)
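A minimal usage sketch in the spirit of the updated test: resolve an initializer by name and apply it to a toy module tree. It assumes the stock `fcs`/`norms` registrations cover `nn.Linear` and `nn.LayerNorm`, and that importing `llmfoundry.models.utils.param_init_fns` (directly or indirectly) populates the built-in entries:

```python
from functools import partial

import torch.nn as nn

import llmfoundry.models.utils.param_init_fns  # noqa: F401 (registers the built-ins)
from llmfoundry.layers_registry import param_init_fns

model = nn.Sequential(nn.Linear(16, 16), nn.LayerNorm(16))
model.apply(partial(param_init_fns.get('kaiming_normal_'), n_layers=2, d_model=16))
```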
diff --git a/tests/test_registry.py b/tests/test_registry.py
index aaba89c43d..d7a1fc7dfe 100644
--- a/tests/test_registry.py
+++ b/tests/test_registry.py
@@ -31,6 +31,8 @@ def test_expected_registries_exist():
         'metrics',
         'models',
         'norms',
+        'param_init_fns',
+        'module_init_fns',
         'ffns',
         'ffns_with_norm',
         'ffns_with_megablocks',
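A quick sanity check in the spirit of `test_expected_registries_exist`, assuming `get_all()` returns a name-keyed mapping (as its use in `generic_param_init_fn_` above suggests):

```python
import llmfoundry.models.utils.param_init_fns  # noqa: F401 (registers the built-ins)
from llmfoundry.layers_registry import module_init_fns, param_init_fns

# Built-in initializers are exposed under their old MODEL_INIT_REGISTRY names.
assert 'kaiming_normal_' in param_init_fns.get_all()

# All six per-module handlers registered at the bottom of param_init_fns.py are present.
expected = {'fc', 'embedding', 'norm', 'multihead_attention', 'te_layernorm_mlp', 'moe'}
assert expected <= set(module_init_fns.get_all())
```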