From 7136162bf5fd6df08e6a54e950686570086022fe Mon Sep 17 00:00:00 2001 From: Daniel King Date: Wed, 3 Apr 2024 22:36:00 -0700 Subject: [PATCH 01/17] param init registry --- llmfoundry/layers_registry.py | 20 ++- llmfoundry/models/mpt/modeling_mpt.py | 7 +- llmfoundry/models/utils/__init__.py | 4 +- llmfoundry/models/utils/param_init_fns.py | 186 ++++++++++++++++------ tests/models/utils/test_param_init_fns.py | 5 +- 5 files changed, 163 insertions(+), 59 deletions(-) diff --git a/llmfoundry/layers_registry.py b/llmfoundry/layers_registry.py index 9c7dabe128..855cf351fa 100644 --- a/llmfoundry/layers_registry.py +++ b/llmfoundry/layers_registry.py @@ -1,7 +1,7 @@ # Copyright 2024 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 -from typing import Type +from typing import Callable, Type import torch @@ -15,6 +15,24 @@ entry_points=True, description=_norm_description) +_param_init_fns_description = """The param_init_fns registry is used to register functions that initialize parameters.""" +param_init_fns = create_registry('llmfoundry', + 'param_init_fns', + generic_type=Callable[..., None], + entry_points=True, + description=_param_init_fns_description) + +_module_init_fns_description = """The module_init_fns registry is used to register functions that initialize specific modules. +These functions should return True if they initialize the module, and False otherwise. This allows them to be called without knowing their contents. +They should take in the module, init_div_is_residual, and div_is_residual arguments.""" +module_init_fns = create_registry('llmfoundry', + 'module_init_fns', + generic_type=Callable[..., bool], + entry_points=True, + description=_module_init_fns_description) + __all__ = [ 'norms', + 'param_init_fns', + 'module_init_fns', ] diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index d54b797269..7288085267 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -41,7 +41,7 @@ from transformers.models.llama.modeling_llama import \ LlamaRotaryEmbedding as HFRotaryEmbedding -from llmfoundry.layers_registry import norms +from llmfoundry.layers_registry import norms, param_init_fns from llmfoundry.models.layers.attention import (attn_bias_shape, build_attn_bias, gen_slopes) from llmfoundry.models.layers.blocks import MPTBlock @@ -58,7 +58,6 @@ init_empty_weights # type: ignore (see note) from llmfoundry.models.utils.param_init_fns import ( generic_param_init_fn_, # type: ignore (see note) - MODEL_INIT_REGISTRY, ) from llmfoundry.models.utils.act_ckpt import (pass_on_block_idx, @@ -660,7 +659,7 @@ def forward( # Param Initialization, needed for device='meta' fast initialization def param_init_fn(self, module: nn.Module) -> None: init_fn_name = self.config.init_config['name'] - MODEL_INIT_REGISTRY[init_fn_name]( + param_init_fns.get(init_fn_name)( module=module, n_layers=self.config.n_layers, d_model=self.config.d_model, @@ -820,7 +819,7 @@ def forward( # Param Initialization, needed for device='meta' fast initialization def param_init_fn(self, module: nn.Module) -> None: init_fn_name = self.config.init_config['name'] - MODEL_INIT_REGISTRY[init_fn_name]( + param_init_fns.get(init_fn_name)( module=module, n_layers=self.config.n_layers, d_model=self.config.d_model, diff --git a/llmfoundry/models/utils/__init__.py b/llmfoundry/models/utils/__init__.py index 7c808ff449..9e46099d3d 100644 --- a/llmfoundry/models/utils/__init__.py +++ b/llmfoundry/models/utils/__init__.py @@ -3,12 +3,10 
@@ from llmfoundry.models.utils.meta_init_context import (init_empty_weights, init_on_device) -from llmfoundry.models.utils.param_init_fns import (MODEL_INIT_REGISTRY, - generic_param_init_fn_) +from llmfoundry.models.utils.param_init_fns import generic_param_init_fn_ __all__ = [ 'init_empty_weights', 'init_on_device', 'generic_param_init_fn_', - 'MODEL_INIT_REGISTRY', ] diff --git a/llmfoundry/models/utils/param_init_fns.py b/llmfoundry/models/utils/param_init_fns.py index 35dc88a408..88618be7c9 100644 --- a/llmfoundry/models/utils/param_init_fns.py +++ b/llmfoundry/models/utils/param_init_fns.py @@ -10,7 +10,7 @@ import torch from torch import nn -from llmfoundry.layers_registry import norms +from llmfoundry.layers_registry import module_init_fns, norms, param_init_fns from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY try: @@ -53,42 +53,16 @@ def fused_init_helper_(module: nn.Module, init_fn_: Callable) -> None: init_fn_(module.weight[slice_indices]) -def generic_param_init_fn_( +def fc_init( module: nn.Module, init_fn_: Callable, - n_layers: int, - d_model: Optional[int] = None, - init_div_is_residual: Union[int, float, str, bool] = True, - emb_init_std: Optional[float] = None, - emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None, + init_div_is_residual: Union[int, float, str, bool], + div_is_residual: Optional[float], **kwargs: Any, -) -> None: - del kwargs # unused, just to capture any extra args from the config - # enable user to divide _is_residual weights by - - # a value which defaults to math.sqrt(2 * cfg.n_layers) - init_div_is_residual = init_div_is_residual - - if init_div_is_residual is False: - # not used, for pyright - div_is_residual = 1.0 - elif init_div_is_residual is True: - div_is_residual = math.sqrt(2 * n_layers) - elif isinstance(init_div_is_residual, float) or isinstance( - init_div_is_residual, int): - div_is_residual = init_div_is_residual - elif init_div_is_residual.isnumeric(): - # do not trust YAML parsing to always convert numbers to numbers - div_is_residual = float(init_div_is_residual) - else: - # not used, for pyright - div_is_residual = 1.0 - raise ValueError( - f'Expected init_div_is_residual to be boolean or numeric, got {init_div_is_residual}' - ) +) -> bool: + del kwargs # unused, just to capture any extra args if isinstance(module, tuple(set(FC_CLASS_REGISTRY.values()))): - # Linear if hasattr(module, '_fused'): fused_init_helper_(module, init_fn_) else: @@ -101,8 +75,21 @@ def generic_param_init_fn_( module, '_is_residual', False): with torch.no_grad(): module.weight.div_(div_is_residual) # type: ignore + return True + + return False + + +def embedding_init( + module: nn.Module, + init_fn_: Callable, + emb_init_std: Optional[float], + emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]], + **kwargs: Any, +) -> bool: + del kwargs # unused, just to capture any extra args - elif isinstance(module, nn.Embedding): + if isinstance(module, nn.Embedding): # Embedding if emb_init_std is not None: std = emb_init_std @@ -129,8 +116,19 @@ def generic_param_init_fn_( emb_init_fn_(module.weight) - elif isinstance(module, - tuple(set([norms.get(name) for name in norms.get_all()]))): + return True + + return False + + +def norm_init( + module: nn.Module, + **kwargs: Any, +) -> bool: + del kwargs # unused, just to capture any extra args + + if isinstance(module, + tuple(set([norms.get(name) for name in norms.get_all()]))): # Norm if hasattr(module, 'weight') and isinstance(module.weight, torch.Tensor): @@ -138,7 +136,22 @@ 
def generic_param_init_fn_( if hasattr(module, 'bias') and isinstance(module.bias, torch.Tensor): torch.nn.init.zeros_(module.bias) - elif isinstance(module, nn.MultiheadAttention): + return True + + return False + + +def multihead_attention_init( + module: nn.Module, + init_fn_: Callable, + d_model: Optional[int], + init_div_is_residual: Union[int, float, str, bool], + div_is_residual: float, + **kwargs: Any, +) -> bool: + del kwargs # unused, just to capture any extra args + + if isinstance(module, nn.MultiheadAttention): # torch's MultiheadAttention if module._qkv_same_embed_dim: assert module.in_proj_weight is not None @@ -173,7 +186,19 @@ def generic_param_init_fn_( if module.out_proj.bias is not None: torch.nn.init.zeros_(module.out_proj.bias) - elif te is not None and isinstance(module, te.LayerNormMLP): + return True + + return False + + +def te_layernorm_mlp_init( + module: nn.Module, + init_fn_: Callable, + **kwargs: Any, +) -> bool: + del kwargs # unused, just to capture any extra args + + if te is not None and isinstance(module, te.LayerNormMLP): if isinstance(module.layer_norm_weight, torch.Tensor): torch.nn.init.ones_(module.layer_norm_weight) if isinstance(module.layer_norm_bias, torch.Tensor): @@ -191,12 +216,71 @@ def generic_param_init_fn_( with torch.no_grad(): module.fc2_weight.div_(div_is_residual) # type: ignore + return True + + return False + + +def generic_param_init_fn_( + module: nn.Module, + init_fn_: Callable, + n_layers: int, + d_model: Optional[int] = None, + init_div_is_residual: Union[int, float, str, bool] = True, + emb_init_std: Optional[float] = None, + emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None, + **kwargs: Any, +) -> None: + del kwargs # unused, just to capture any extra args from the config + # enable user to divide _is_residual weights by + + # a value which defaults to math.sqrt(2 * cfg.n_layers) + init_div_is_residual = init_div_is_residual + + if init_div_is_residual is False: + # not used, for pyright + div_is_residual = 1.0 + elif init_div_is_residual is True: + div_is_residual = math.sqrt(2 * n_layers) + elif isinstance(init_div_is_residual, float) or isinstance( + init_div_is_residual, int): + div_is_residual = init_div_is_residual + elif init_div_is_residual.isnumeric(): + # do not trust YAML parsing to always convert numbers to numbers + div_is_residual = float(init_div_is_residual) else: + # not used, for pyright + div_is_residual = 1.0 + raise ValueError( + f'Expected init_div_is_residual to be boolean or numeric, got {init_div_is_residual}' + ) + + all_module_init_fns = [ + module_init_fns.get(name) for name in module_init_fns.get_all() + ] + did_init = False + for module_init_fn in all_module_init_fns: + did_init = module_init_fn( + module=module, + init_fn_=init_fn_, + d_model=d_model, + init_div_is_residual=init_div_is_residual, + div_is_residual=div_is_residual, + emb_init_std=emb_init_std, + emb_init_uniform_lim=emb_init_uniform_lim, + ) + + if did_init: + break + + if not did_init: for _ in module.parameters(recurse=False): # raise error if uninitialized module has any parameters raise NotImplementedError( - f'{module.__class__.__name__} parameters are not initialized by param_init_fn.' - ) + f'{module.__class__.__name__} parameters are not initialized by any of the registered module_init_fns. ' + + + 'Please add an appropriate module_init_fn to the registry. 
Currently registered module_init_fns are: ' + + ', '.join(module_init_fns.get_all())) def _normal_init_(std: float, mean: float = 0.0) -> Callable: @@ -412,13 +496,17 @@ def xavier_normal_param_init_fn_( ) -MODEL_INIT_REGISTRY = { - 'default_': torch_default_param_init_fn_, - 'baseline_': baseline_param_init_fn_, - 'kaiming_uniform_': kaiming_uniform_param_init_fn_, - 'kaiming_normal_': kaiming_normal_param_init_fn_, - 'neox_init_': neox_param_init_fn_, - 'small_init_': small_param_init_fn_, - 'xavier_uniform_': xavier_uniform_param_init_fn_, - 'xavier_normal_': xavier_normal_param_init_fn_, -} +param_init_fns.register('default_', func=torch_default_param_init_fn_) +param_init_fns.register('baseline_', func=baseline_param_init_fn_) +param_init_fns.register('kaiming_uniform_', func=kaiming_uniform_param_init_fn_) +param_init_fns.register('kaiming_normal_', func=kaiming_normal_param_init_fn_) +param_init_fns.register('neox_init_', func=neox_param_init_fn_) +param_init_fns.register('small_init_', func=small_param_init_fn_) +param_init_fns.register('xavier_uniform_', func=xavier_uniform_param_init_fn_) +param_init_fns.register('xavier_normal_', func=xavier_normal_param_init_fn_) + +module_init_fns.register('fc', func=fc_init) +module_init_fns.register('embedding', func=embedding_init) +module_init_fns.register('norm', func=norm_init) +module_init_fns.register('multihead_attention', func=multihead_attention_init) +module_init_fns.register('te_layernorm_mlp', func=te_layernorm_mlp_init) diff --git a/tests/models/utils/test_param_init_fns.py b/tests/models/utils/test_param_init_fns.py index 6be2c5ca42..0efc245602 100644 --- a/tests/models/utils/test_param_init_fns.py +++ b/tests/models/utils/test_param_init_fns.py @@ -12,7 +12,8 @@ from omegaconf import OmegaConf as om from torch import nn -from llmfoundry.models.utils import MODEL_INIT_REGISTRY, generic_param_init_fn_ +from llmfoundry.layers_registry import param_init_fns +from llmfoundry.models.utils import generic_param_init_fn_ class MLP(nn.Module): @@ -150,7 +151,7 @@ def test_emb_init(emb_init_cfg: Optional[Tuple[str, Union[int, List[int]]]]): bias=True)), ])) - model.apply(partial(MODEL_INIT_REGISTRY['kaiming_normal_'], **dict_cfg)) + model.apply(partial(param_init_fns.get('kaiming_normal_'), **dict_cfg)) assert isinstance(model.emb, torch.nn.Embedding) From 406df70893c27d4f71bdb6644debbe7e53922596 Mon Sep 17 00:00:00 2001 From: Daniel King Date: Wed, 3 Apr 2024 22:41:03 -0700 Subject: [PATCH 02/17] temp test --- llmfoundry/models/utils/param_init_fns.py | 52 +++++++++++------------ 1 file changed, 26 insertions(+), 26 deletions(-) diff --git a/llmfoundry/models/utils/param_init_fns.py b/llmfoundry/models/utils/param_init_fns.py index 88618be7c9..b05a537f60 100644 --- a/llmfoundry/models/utils/param_init_fns.py +++ b/llmfoundry/models/utils/param_init_fns.py @@ -255,32 +255,32 @@ def generic_param_init_fn_( f'Expected init_div_is_residual to be boolean or numeric, got {init_div_is_residual}' ) - all_module_init_fns = [ - module_init_fns.get(name) for name in module_init_fns.get_all() - ] - did_init = False - for module_init_fn in all_module_init_fns: - did_init = module_init_fn( - module=module, - init_fn_=init_fn_, - d_model=d_model, - init_div_is_residual=init_div_is_residual, - div_is_residual=div_is_residual, - emb_init_std=emb_init_std, - emb_init_uniform_lim=emb_init_uniform_lim, - ) - - if did_init: - break - - if not did_init: - for _ in module.parameters(recurse=False): - # raise error if uninitialized module has any parameters - 
raise NotImplementedError( - f'{module.__class__.__name__} parameters are not initialized by any of the registered module_init_fns. ' - + - 'Please add an appropriate module_init_fn to the registry. Currently registered module_init_fns are: ' - + ', '.join(module_init_fns.get_all())) + # all_module_init_fns = [ + # module_init_fns.get(name) for name in module_init_fns.get_all() + # ] + # did_init = False + # for module_init_fn in all_module_init_fns: + # did_init = module_init_fn( + # module=module, + # init_fn_=init_fn_, + # d_model=d_model, + # init_div_is_residual=init_div_is_residual, + # div_is_residual=div_is_residual, + # emb_init_std=emb_init_std, + # emb_init_uniform_lim=emb_init_uniform_lim, + # ) + + # if did_init: + # break + + # if not did_init: + # for _ in module.parameters(recurse=False): + # # raise error if uninitialized module has any parameters + # raise NotImplementedError( + # f'{module.__class__.__name__} parameters are not initialized by any of the registered module_init_fns. ' + # + + # 'Please add an appropriate module_init_fn to the registry. Currently registered module_init_fns are: ' + # + ', '.join(module_init_fns.get_all())) def _normal_init_(std: float, mean: float = 0.0) -> Callable: From 0d612ca9a044ce91ee4e404cb52bd337870c781c Mon Sep 17 00:00:00 2001 From: Daniel King Date: Wed, 3 Apr 2024 22:44:19 -0700 Subject: [PATCH 03/17] put it back --- llmfoundry/models/utils/param_init_fns.py | 52 +++++++++++------------ 1 file changed, 26 insertions(+), 26 deletions(-) diff --git a/llmfoundry/models/utils/param_init_fns.py b/llmfoundry/models/utils/param_init_fns.py index b05a537f60..88618be7c9 100644 --- a/llmfoundry/models/utils/param_init_fns.py +++ b/llmfoundry/models/utils/param_init_fns.py @@ -255,32 +255,32 @@ def generic_param_init_fn_( f'Expected init_div_is_residual to be boolean or numeric, got {init_div_is_residual}' ) - # all_module_init_fns = [ - # module_init_fns.get(name) for name in module_init_fns.get_all() - # ] - # did_init = False - # for module_init_fn in all_module_init_fns: - # did_init = module_init_fn( - # module=module, - # init_fn_=init_fn_, - # d_model=d_model, - # init_div_is_residual=init_div_is_residual, - # div_is_residual=div_is_residual, - # emb_init_std=emb_init_std, - # emb_init_uniform_lim=emb_init_uniform_lim, - # ) - - # if did_init: - # break - - # if not did_init: - # for _ in module.parameters(recurse=False): - # # raise error if uninitialized module has any parameters - # raise NotImplementedError( - # f'{module.__class__.__name__} parameters are not initialized by any of the registered module_init_fns. ' - # + - # 'Please add an appropriate module_init_fn to the registry. Currently registered module_init_fns are: ' - # + ', '.join(module_init_fns.get_all())) + all_module_init_fns = [ + module_init_fns.get(name) for name in module_init_fns.get_all() + ] + did_init = False + for module_init_fn in all_module_init_fns: + did_init = module_init_fn( + module=module, + init_fn_=init_fn_, + d_model=d_model, + init_div_is_residual=init_div_is_residual, + div_is_residual=div_is_residual, + emb_init_std=emb_init_std, + emb_init_uniform_lim=emb_init_uniform_lim, + ) + + if did_init: + break + + if not did_init: + for _ in module.parameters(recurse=False): + # raise error if uninitialized module has any parameters + raise NotImplementedError( + f'{module.__class__.__name__} parameters are not initialized by any of the registered module_init_fns. ' + + + 'Please add an appropriate module_init_fn to the registry. 
Currently registered module_init_fns are: ' + + ', '.join(module_init_fns.get_all())) def _normal_init_(std: float, mean: float = 0.0) -> Callable: From 07968c4fee0a25ea700e48bdbc8beaedc6852623 Mon Sep 17 00:00:00 2001 From: Daniel King Date: Thu, 4 Apr 2024 19:03:22 -0700 Subject: [PATCH 04/17] clean up --- llmfoundry/layers_registry.py | 6 +++++- llmfoundry/registry.py | 4 +++- tests/test_registry.py | 2 ++ 3 files changed, 10 insertions(+), 2 deletions(-) diff --git a/llmfoundry/layers_registry.py b/llmfoundry/layers_registry.py index 855cf351fa..c28ea5f20b 100644 --- a/llmfoundry/layers_registry.py +++ b/llmfoundry/layers_registry.py @@ -15,7 +15,11 @@ entry_points=True, description=_norm_description) -_param_init_fns_description = """The param_init_fns registry is used to register functions that initialize parameters.""" +_param_init_fns_description = ( + 'The param_init_fns registry is used to register functions that initialize parameters.' + + + 'These will be called on a module to initialize its parameters. See param_init_fns.py for examples.' +) param_init_fns = create_registry('llmfoundry', 'param_init_fns', generic_type=Callable[..., None], diff --git a/llmfoundry/registry.py b/llmfoundry/registry.py index 424075da3b..d9ad085e2f 100644 --- a/llmfoundry/registry.py +++ b/llmfoundry/registry.py @@ -12,7 +12,7 @@ from transformers import PreTrainedTokenizerBase from llmfoundry.interfaces import CallbackWithConfig -from llmfoundry.layers_registry import norms +from llmfoundry.layers_registry import module_init_fns, norms, param_init_fns from llmfoundry.utils.registry_utils import create_registry _loggers_description = ( @@ -121,4 +121,6 @@ 'metrics', 'dataloaders', 'norms', + 'param_init_fns', + 'module_init_fns', ] diff --git a/tests/test_registry.py b/tests/test_registry.py index c93c7c9749..c6b12466f5 100644 --- a/tests/test_registry.py +++ b/tests/test_registry.py @@ -31,6 +31,8 @@ def test_expected_registries_exist(): 'metrics', 'models', 'norms', + 'param_init_fns', + 'module_init_fns', } assert existing_registries == expected_registry_names From 7b22efcabd69acd4982cc544d39dd1af7c3ea642 Mon Sep 17 00:00:00 2001 From: Daniel King Date: Thu, 11 Apr 2024 20:17:44 -0700 Subject: [PATCH 05/17] pc --- llmfoundry/models/utils/param_init_fns.py | 78 ++++++++++++++--------- 1 file changed, 47 insertions(+), 31 deletions(-) diff --git a/llmfoundry/models/utils/param_init_fns.py b/llmfoundry/models/utils/param_init_fns.py index d2777c73d1..e64cde5e96 100644 --- a/llmfoundry/models/utils/param_init_fns.py +++ b/llmfoundry/models/utils/param_init_fns.py @@ -317,6 +317,52 @@ def te_layernorm_mlp_init( return False +def moe_init( + module: nn.Module, + init_fn_: Callable, + init_div_is_residual: Union[int, float, str, bool], + div_is_residual: float, + **kwargs: Any, +) -> bool: + if megablocks is not None and isinstance(module, ( + megablocks.layers.moe.MoE, + megablocks.layers.dmoe.dMoE, + megablocks.layers.moe.ParallelMLP, + megablocks.layers.dmoe.ParallelDroplessMLP, + )): + if hasattr(module, 'bias') and module.bias is not None: + # Initialize bias to 0 + torch.nn.init.zeros_(module.bias) # type: ignore + return True + elif megablocks is not None and isinstance(module, + megablocks.layers.glu.SparseGLU): + _megablocks_sparse_glu_generic_param_init_fn_( + module, init_fn_, bool(init_div_is_residual), div_is_residual) + return True + elif megablocks is not None and isinstance(module, + megablocks.layers.mlp.SparseMLP): + _megablocks_sparse_mlp_generic_param_init_fn_( + module, init_fn_, 
bool(init_div_is_residual), div_is_residual) + return True + elif megablocks is not None and isinstance(module, + megablocks.layers.mlp.MLP): + _megablocks_mlp_generic_param_init_fn_(module, init_fn_, + bool(init_div_is_residual), + div_is_residual) + return True + elif isinstance(module, GLU): + init_fn_(module.w1) + init_fn_(module.v1) + init_fn_(module.w2) + return True + elif isinstance(module, MLP): + init_fn_(module.w1) + init_fn_(module.w2) + return True + + return False + + def generic_param_init_fn_( module: nn.Module, init_fn_: Callable, @@ -351,37 +397,6 @@ def generic_param_init_fn_( f'Expected init_div_is_residual to be boolean or numeric, got {init_div_is_residual}' ) - # TODO: FINISH MERGING - # elif megablocks is not None and isinstance(module, ( - # megablocks.layers.moe.MoE, - # megablocks.layers.dmoe.dMoE, - # megablocks.layers.moe.ParallelMLP, - # megablocks.layers.dmoe.ParallelDroplessMLP, - # )): - # if hasattr(module, 'bias') and module.bias is not None: - # # Initialize bias to 0 - # torch.nn.init.zeros_(module.bias) # type: ignore - # elif megablocks is not None and isinstance(module, - # megablocks.layers.glu.SparseGLU): - # _megablocks_sparse_glu_generic_param_init_fn_( - # module, init_fn_, bool(init_div_is_residual), div_is_residual) - # elif megablocks is not None and isinstance(module, - # megablocks.layers.mlp.SparseMLP): - # _megablocks_sparse_mlp_generic_param_init_fn_( - # module, init_fn_, bool(init_div_is_residual), div_is_residual) - # elif megablocks is not None and isinstance(module, - # megablocks.layers.mlp.MLP): - # _megablocks_mlp_generic_param_init_fn_(module, init_fn_, - # bool(init_div_is_residual), - # div_is_residual) - # elif isinstance(module, GLU): - # init_fn_(module.w1) - # init_fn_(module.v1) - # init_fn_(module.w2) - # elif isinstance(module, MLP): - # init_fn_(module.w1) - # init_fn_(module.w2) - all_module_init_fns = [ module_init_fns.get(name) for name in module_init_fns.get_all() ] @@ -827,3 +842,4 @@ def xavier_normal_param_init_fn_( module_init_fns.register('norm', func=norm_init) module_init_fns.register('multihead_attention', func=multihead_attention_init) module_init_fns.register('te_layernorm_mlp', func=te_layernorm_mlp_init) +module_init_fns.register('moe', func=moe_init) From a088fbfe5a8ffd806b2b41ee99c8c1bcb0b5ade0 Mon Sep 17 00:00:00 2001 From: Daniel King Date: Fri, 12 Apr 2024 05:04:12 +0000 Subject: [PATCH 06/17] logs --- llmfoundry/models/utils/param_init_fns.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/llmfoundry/models/utils/param_init_fns.py b/llmfoundry/models/utils/param_init_fns.py index e64cde5e96..62ca74c29f 100644 --- a/llmfoundry/models/utils/param_init_fns.py +++ b/llmfoundry/models/utils/param_init_fns.py @@ -324,6 +324,10 @@ def moe_init( div_is_residual: float, **kwargs: Any, ) -> bool: + print('in moe init') + print(type(module)) + print(isinstance(module, GLU)) + print(isinstance(module, MLP)) if megablocks is not None and isinstance(module, ( megablocks.layers.moe.MoE, megablocks.layers.dmoe.dMoE, @@ -400,8 +404,11 @@ def generic_param_init_fn_( all_module_init_fns = [ module_init_fns.get(name) for name in module_init_fns.get_all() ] + print('in init') + print(all_module_init_fns) did_init = False for module_init_fn in all_module_init_fns: + print(type(module), module_init_fn) did_init = module_init_fn( module=module, init_fn_=init_fn_, From 66886158b65ce099689f69d50aeebbd42e49db85 Mon Sep 17 00:00:00 2001 From: Daniel King Date: Fri, 12 Apr 2024 05:30:30 +0000 Subject: [PATCH 07/17] 
more --- llmfoundry/models/utils/param_init_fns.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/llmfoundry/models/utils/param_init_fns.py b/llmfoundry/models/utils/param_init_fns.py index 62ca74c29f..d323155388 100644 --- a/llmfoundry/models/utils/param_init_fns.py +++ b/llmfoundry/models/utils/param_init_fns.py @@ -328,6 +328,8 @@ def moe_init( print(type(module)) print(isinstance(module, GLU)) print(isinstance(module, MLP)) + print(GLU) + print(MLP) if megablocks is not None and isinstance(module, ( megablocks.layers.moe.MoE, megablocks.layers.dmoe.dMoE, From 8981209546ce9ebc6407f8c518df69df2fec7951 Mon Sep 17 00:00:00 2001 From: Daniel King Date: Fri, 12 Apr 2024 05:34:13 +0000 Subject: [PATCH 08/17] try del cache --- tests/models/layers/test_dmoe.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/tests/models/layers/test_dmoe.py b/tests/models/layers/test_dmoe.py index 9c15745793..082a1ad02a 100644 --- a/tests/models/layers/test_dmoe.py +++ b/tests/models/layers/test_dmoe.py @@ -7,6 +7,8 @@ from typing import List, Optional import pytest +import shutil +import os import torch import torch.distributed as dist import torch.nn.functional as F @@ -187,6 +189,21 @@ def test_dmoe(moe_num_experts: int, mlp_type: str, moe_world_size: int, mb_y = mb_dmoe(x) torch.testing.assert_close(torch_y, mb_y) +# TODO(GRT-2435): Change to fixture +def delete_transformers_cache(): + # Only delete the files on local rank 0, otherwise race conditions are created + if not dist.get_local_rank() == 0: + return + + hf_cache_home = os.path.expanduser( + os.getenv( + 'HF_HOME', + os.path.join(os.getenv('XDG_CACHE_HOME', '~/.cache'), + 'huggingface'))) + HF_MODULES_CACHE = os.getenv('HF_MODULES_CACHE', + os.path.join(hf_cache_home, 'modules')) + if os.path.exists(HF_MODULES_CACHE) and os.path.isdir(HF_MODULES_CACHE): + shutil.rmtree(HF_MODULES_CACHE) @pytest.mark.skipif(not is_megablocks_imported, reason='This test needs megablocks module') @@ -195,6 +212,8 @@ def test_dmoe(moe_num_experts: int, mlp_type: str, moe_world_size: int, @pytest.mark.parametrize('mlp_type', ['glu', 'mlp']) @pytest.mark.parametrize('precision', ['bf16', 'fp32']) def test_fwd_equal_dmoe(seqlen: int, precision: str, mlp_type: str): + delete_transformers_cache() + mb_dmoe_config = MPTConfig(d_model=1024, n_heads=32, n_layers=1, @@ -261,3 +280,5 @@ def test_fwd_equal_dmoe(seqlen: int, precision: str, mlp_type: str): mpt_logits = mb_dmoe_model(token_ids).logits db_logits = torch_dmoe_model(token_ids).logits assert torch.allclose(mpt_logits, db_logits, rtol=0.01, atol=0.01) + + delete_transformers_cache() From 4fe928c164b6198671c727f7e952fe0cf3e98163 Mon Sep 17 00:00:00 2001 From: Daniel King Date: Fri, 12 Apr 2024 05:36:12 +0000 Subject: [PATCH 09/17] fix --- tests/models/layers/test_dmoe.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/models/layers/test_dmoe.py b/tests/models/layers/test_dmoe.py index 082a1ad02a..10bbe6427e 100644 --- a/tests/models/layers/test_dmoe.py +++ b/tests/models/layers/test_dmoe.py @@ -20,6 +20,8 @@ from torch.distributed.tensor.parallel.ddp import _pre_dp_module_transform from torch.nn.parallel import DistributedDataParallel as DDP +from composer.utils import dist as cdist + from llmfoundry.models.layers.dmoe import dMoE from llmfoundry.models.layers.ffn import dtensorify_param from llmfoundry.models.mpt.configuration_mpt import MPTConfig @@ -192,7 +194,7 @@ def test_dmoe(moe_num_experts: int, mlp_type: str, moe_world_size: int, # TODO(GRT-2435): Change to 
fixture def delete_transformers_cache(): # Only delete the files on local rank 0, otherwise race conditions are created - if not dist.get_local_rank() == 0: + if not cdist.get_local_rank() == 0: return hf_cache_home = os.path.expanduser( From 88a65115f26aa2f9fa50272e5c024bb6ec805c89 Mon Sep 17 00:00:00 2001 From: Daniel King Date: Thu, 11 Apr 2024 22:37:20 -0700 Subject: [PATCH 10/17] fix --- tests/models/layers/test_dmoe.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/models/layers/test_dmoe.py b/tests/models/layers/test_dmoe.py index 10bbe6427e..328140d4a3 100644 --- a/tests/models/layers/test_dmoe.py +++ b/tests/models/layers/test_dmoe.py @@ -2,17 +2,18 @@ # SPDX-License-Identifier: Apache-2.0 import copy +import os +import shutil from contextlib import nullcontext from functools import partial from typing import List, Optional import pytest -import shutil -import os import torch import torch.distributed as dist import torch.nn.functional as F import torch.optim as optim +from composer.utils import dist as cdist from torch.distributed._tensor import DTensor, Placement, Replicate, Shard from torch.distributed._tensor.device_mesh import init_device_mesh from torch.distributed.checkpoint.state_dict import (StateDictOptions, @@ -20,8 +21,6 @@ from torch.distributed.tensor.parallel.ddp import _pre_dp_module_transform from torch.nn.parallel import DistributedDataParallel as DDP -from composer.utils import dist as cdist - from llmfoundry.models.layers.dmoe import dMoE from llmfoundry.models.layers.ffn import dtensorify_param from llmfoundry.models.mpt.configuration_mpt import MPTConfig @@ -191,6 +190,7 @@ def test_dmoe(moe_num_experts: int, mlp_type: str, moe_world_size: int, mb_y = mb_dmoe(x) torch.testing.assert_close(torch_y, mb_y) + # TODO(GRT-2435): Change to fixture def delete_transformers_cache(): # Only delete the files on local rank 0, otherwise race conditions are created @@ -207,6 +207,7 @@ def delete_transformers_cache(): if os.path.exists(HF_MODULES_CACHE) and os.path.isdir(HF_MODULES_CACHE): shutil.rmtree(HF_MODULES_CACHE) + @pytest.mark.skipif(not is_megablocks_imported, reason='This test needs megablocks module') @pytest.mark.gpu From bc87d1dd1b4f0fe44a45ce07a27a7d277c218603 Mon Sep 17 00:00:00 2001 From: Daniel King Date: Fri, 12 Apr 2024 06:29:55 +0000 Subject: [PATCH 11/17] test --- tests/models/layers/test_dmoe.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tests/models/layers/test_dmoe.py b/tests/models/layers/test_dmoe.py index 328140d4a3..58d110f9c4 100644 --- a/tests/models/layers/test_dmoe.py +++ b/tests/models/layers/test_dmoe.py @@ -217,6 +217,13 @@ def delete_transformers_cache(): def test_fwd_equal_dmoe(seqlen: int, precision: str, mlp_type: str): delete_transformers_cache() + import llmfoundry + print(llmfoundry.layers_registry.ffns.get_all()) + + from llmfoundry.models.layers.ffn import resolve_ffn_act_fn # type: ignore + + print(llmfoundry.layers_registry.ffns.get_all()) + mb_dmoe_config = MPTConfig(d_model=1024, n_heads=32, n_layers=1, From d110f74cae0c920fa780e7c8cb1d14a8f8eed0f4 Mon Sep 17 00:00:00 2001 From: Daniel King Date: Fri, 12 Apr 2024 06:33:36 +0000 Subject: [PATCH 12/17] fix --- tests/models/layers/test_dmoe.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/models/layers/test_dmoe.py b/tests/models/layers/test_dmoe.py index 58d110f9c4..e1bfd97939 100644 --- a/tests/models/layers/test_dmoe.py +++ b/tests/models/layers/test_dmoe.py @@ -217,12 +217,12 @@ def 
delete_transformers_cache(): def test_fwd_equal_dmoe(seqlen: int, precision: str, mlp_type: str): delete_transformers_cache() - import llmfoundry - print(llmfoundry.layers_registry.ffns.get_all()) + from llmfoundry.layers_registry import module_init_fns + print(module_init_fns.get_all()) from llmfoundry.models.layers.ffn import resolve_ffn_act_fn # type: ignore - print(llmfoundry.layers_registry.ffns.get_all()) + print(module_init_fns.get_all()) mb_dmoe_config = MPTConfig(d_model=1024, n_heads=32, From c4dd2fdf4b4a2024b31286d42f4fdfb797759499 Mon Sep 17 00:00:00 2001 From: Daniel King Date: Fri, 12 Apr 2024 07:10:38 +0000 Subject: [PATCH 13/17] maybe fix --- llmfoundry/utils/registry_utils.py | 10 ++++++++++ tests/fixtures/autouse.py | 6 ++++++ 2 files changed, 16 insertions(+) diff --git a/llmfoundry/utils/registry_utils.py b/llmfoundry/utils/registry_utils.py index d9c23e6f26..d7861cc557 100644 --- a/llmfoundry/utils/registry_utils.py +++ b/llmfoundry/utils/registry_utils.py @@ -3,6 +3,7 @@ import functools import importlib.util +from contextlib import contextmanager import os from pathlib import Path from types import ModuleType @@ -174,3 +175,12 @@ def import_file(loc: Union[str, Path]) -> ModuleType: except Exception as e: raise RuntimeError(f'Error executing {loc}') from e return module + +@contextmanager +def save_registry(): + """Save the registry state and restore after the context manager exits.""" + saved_registry_state = copy.deepcopy(catalogue.REGISTRY) + + yield + + catalogue.REGISTRY = saved_registry_state \ No newline at end of file diff --git a/tests/fixtures/autouse.py b/tests/fixtures/autouse.py index ccbe1b69f7..6d6a5ad006 100644 --- a/tests/fixtures/autouse.py +++ b/tests/fixtures/autouse.py @@ -9,10 +9,16 @@ import torch from composer.utils import dist, get_device, reproducibility +from llmfoundry.utils.registry_utils import save_registry + # Add llm-foundry repo root to path so we can import scripts in the tests REPO_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')) sys.path.append(REPO_DIR) +@pytest.fixture(autouse=True) +def save_registry_fixture(): + with save_registry(): + yield @pytest.fixture(autouse=True) def initialize_dist(request: pytest.FixtureRequest): From 882899accbd4a39b31fbabb9df6565f18fbcfe1e Mon Sep 17 00:00:00 2001 From: Daniel King Date: Fri, 12 Apr 2024 00:11:56 -0700 Subject: [PATCH 14/17] pc --- llmfoundry/utils/registry_utils.py | 6 ++++-- tests/fixtures/autouse.py | 2 ++ 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/llmfoundry/utils/registry_utils.py b/llmfoundry/utils/registry_utils.py index d7861cc557..1e65963f8a 100644 --- a/llmfoundry/utils/registry_utils.py +++ b/llmfoundry/utils/registry_utils.py @@ -1,10 +1,11 @@ # Copyright 2024 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 +import copy import functools import importlib.util -from contextlib import contextmanager import os +from contextlib import contextmanager from pathlib import Path from types import ModuleType from typing import (Any, Callable, Dict, Generic, Optional, Sequence, Type, @@ -176,6 +177,7 @@ def import_file(loc: Union[str, Path]) -> ModuleType: raise RuntimeError(f'Error executing {loc}') from e return module + @contextmanager def save_registry(): """Save the registry state and restore after the context manager exits.""" @@ -183,4 +185,4 @@ def save_registry(): yield - catalogue.REGISTRY = saved_registry_state \ No newline at end of file + catalogue.REGISTRY = saved_registry_state diff --git 
a/tests/fixtures/autouse.py b/tests/fixtures/autouse.py index 6d6a5ad006..16e3f8ad6f 100644 --- a/tests/fixtures/autouse.py +++ b/tests/fixtures/autouse.py @@ -15,11 +15,13 @@ REPO_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')) sys.path.append(REPO_DIR) + @pytest.fixture(autouse=True) def save_registry_fixture(): with save_registry(): yield + @pytest.fixture(autouse=True) def initialize_dist(request: pytest.FixtureRequest): """Initialize the default PyTorch distributed process group for tests.""" From ef1b964669292f234e609f1ad33552b92799682b Mon Sep 17 00:00:00 2001 From: Daniel King Date: Fri, 12 Apr 2024 00:31:28 -0700 Subject: [PATCH 15/17] remove prints --- llmfoundry/models/utils/param_init_fns.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/llmfoundry/models/utils/param_init_fns.py b/llmfoundry/models/utils/param_init_fns.py index d323155388..7ddac5621c 100644 --- a/llmfoundry/models/utils/param_init_fns.py +++ b/llmfoundry/models/utils/param_init_fns.py @@ -324,12 +324,6 @@ def moe_init( div_is_residual: float, **kwargs: Any, ) -> bool: - print('in moe init') - print(type(module)) - print(isinstance(module, GLU)) - print(isinstance(module, MLP)) - print(GLU) - print(MLP) if megablocks is not None and isinstance(module, ( megablocks.layers.moe.MoE, megablocks.layers.dmoe.dMoE, From f07248769707b14fb49083d5ae2c54b90a0cb92e Mon Sep 17 00:00:00 2001 From: Daniel King Date: Fri, 12 Apr 2024 00:33:16 -0700 Subject: [PATCH 16/17] more prints --- llmfoundry/models/utils/param_init_fns.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/llmfoundry/models/utils/param_init_fns.py b/llmfoundry/models/utils/param_init_fns.py index 7ddac5621c..e64cde5e96 100644 --- a/llmfoundry/models/utils/param_init_fns.py +++ b/llmfoundry/models/utils/param_init_fns.py @@ -400,11 +400,8 @@ def generic_param_init_fn_( all_module_init_fns = [ module_init_fns.get(name) for name in module_init_fns.get_all() ] - print('in init') - print(all_module_init_fns) did_init = False for module_init_fn in all_module_init_fns: - print(type(module), module_init_fn) did_init = module_init_fn( module=module, init_fn_=init_fn_, From 4ead37cbe1275b1c97ef3a48137a169faba857b9 Mon Sep 17 00:00:00 2001 From: Daniel King Date: Fri, 12 Apr 2024 00:35:57 -0700 Subject: [PATCH 17/17] more debug cleanup --- tests/models/layers/test_dmoe.py | 31 ------------------------------- 1 file changed, 31 deletions(-) diff --git a/tests/models/layers/test_dmoe.py b/tests/models/layers/test_dmoe.py index e1bfd97939..9c15745793 100644 --- a/tests/models/layers/test_dmoe.py +++ b/tests/models/layers/test_dmoe.py @@ -2,8 +2,6 @@ # SPDX-License-Identifier: Apache-2.0 import copy -import os -import shutil from contextlib import nullcontext from functools import partial from typing import List, Optional @@ -13,7 +11,6 @@ import torch.distributed as dist import torch.nn.functional as F import torch.optim as optim -from composer.utils import dist as cdist from torch.distributed._tensor import DTensor, Placement, Replicate, Shard from torch.distributed._tensor.device_mesh import init_device_mesh from torch.distributed.checkpoint.state_dict import (StateDictOptions, @@ -191,23 +188,6 @@ def test_dmoe(moe_num_experts: int, mlp_type: str, moe_world_size: int, torch.testing.assert_close(torch_y, mb_y) -# TODO(GRT-2435): Change to fixture -def delete_transformers_cache(): - # Only delete the files on local rank 0, otherwise race conditions are created - if not cdist.get_local_rank() == 0: - return - - hf_cache_home = 
os.path.expanduser( - os.getenv( - 'HF_HOME', - os.path.join(os.getenv('XDG_CACHE_HOME', '~/.cache'), - 'huggingface'))) - HF_MODULES_CACHE = os.getenv('HF_MODULES_CACHE', - os.path.join(hf_cache_home, 'modules')) - if os.path.exists(HF_MODULES_CACHE) and os.path.isdir(HF_MODULES_CACHE): - shutil.rmtree(HF_MODULES_CACHE) - - @pytest.mark.skipif(not is_megablocks_imported, reason='This test needs megablocks module') @pytest.mark.gpu @@ -215,15 +195,6 @@ def delete_transformers_cache(): @pytest.mark.parametrize('mlp_type', ['glu', 'mlp']) @pytest.mark.parametrize('precision', ['bf16', 'fp32']) def test_fwd_equal_dmoe(seqlen: int, precision: str, mlp_type: str): - delete_transformers_cache() - - from llmfoundry.layers_registry import module_init_fns - print(module_init_fns.get_all()) - - from llmfoundry.models.layers.ffn import resolve_ffn_act_fn # type: ignore - - print(module_init_fns.get_all()) - mb_dmoe_config = MPTConfig(d_model=1024, n_heads=32, n_layers=1, @@ -290,5 +261,3 @@ def test_fwd_equal_dmoe(seqlen: int, precision: str, mlp_type: str): mpt_logits = mb_dmoe_model(token_ids).logits db_logits = torch_dmoe_model(token_ids).logits assert torch.allclose(mpt_logits, db_logits, rtol=0.01, atol=0.01) - - delete_transformers_cache()
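
For reference, a minimal downstream-usage sketch of the registries introduced in these patches (not part of the patch series itself): `MyCustomLinear` and `my_custom_linear_init` are hypothetical names, but the `module_init_fns.register(...)` call and the `module`/`init_fn_`/`**kwargs` signature returning a bool mirror the contract added above. Once registered, `generic_param_init_fn_` would try this handler alongside the built-in ones and stop at the first that returns True.

from typing import Any, Callable

import torch
from torch import nn

from llmfoundry.layers_registry import module_init_fns


class MyCustomLinear(nn.Linear):
    # Hypothetical custom layer used only for illustration.
    pass


def my_custom_linear_init(
    module: nn.Module,
    init_fn_: Callable,
    **kwargs: Any,
) -> bool:
    # Follow the module_init_fns contract: initialize the module if it is
    # one we recognize and return True, otherwise return False so the next
    # registered handler gets a chance.
    del kwargs  # unused, just to capture the extra args passed by the dispatch loop
    if isinstance(module, MyCustomLinear):
        init_fn_(module.weight)
        if module.bias is not None:
            torch.nn.init.zeros_(module.bias)
        return True
    return False


# Registering under a new name makes generic_param_init_fn_ pick this up
# automatically, in the same way the built-in 'fc', 'embedding', etc. handlers
# are registered at the bottom of param_init_fns.py.
module_init_fns.register('my_custom_linear', func=my_custom_linear_init)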