From 406df70893c27d4f71bdb6644debbe7e53922596 Mon Sep 17 00:00:00 2001 From: Daniel King Date: Wed, 3 Apr 2024 22:41:03 -0700 Subject: [PATCH] temp test --- llmfoundry/models/utils/param_init_fns.py | 52 +++++++++++------------ 1 file changed, 26 insertions(+), 26 deletions(-) diff --git a/llmfoundry/models/utils/param_init_fns.py b/llmfoundry/models/utils/param_init_fns.py index 88618be7c9..b05a537f60 100644 --- a/llmfoundry/models/utils/param_init_fns.py +++ b/llmfoundry/models/utils/param_init_fns.py @@ -255,32 +255,32 @@ def generic_param_init_fn_( f'Expected init_div_is_residual to be boolean or numeric, got {init_div_is_residual}' ) - all_module_init_fns = [ - module_init_fns.get(name) for name in module_init_fns.get_all() - ] - did_init = False - for module_init_fn in all_module_init_fns: - did_init = module_init_fn( - module=module, - init_fn_=init_fn_, - d_model=d_model, - init_div_is_residual=init_div_is_residual, - div_is_residual=div_is_residual, - emb_init_std=emb_init_std, - emb_init_uniform_lim=emb_init_uniform_lim, - ) - - if did_init: - break - - if not did_init: - for _ in module.parameters(recurse=False): - # raise error if uninitialized module has any parameters - raise NotImplementedError( - f'{module.__class__.__name__} parameters are not initialized by any of the registered module_init_fns. ' - + - 'Please add an appropriate module_init_fn to the registry. Currently registered module_init_fns are: ' - + ', '.join(module_init_fns.get_all())) + # all_module_init_fns = [ + # module_init_fns.get(name) for name in module_init_fns.get_all() + # ] + # did_init = False + # for module_init_fn in all_module_init_fns: + # did_init = module_init_fn( + # module=module, + # init_fn_=init_fn_, + # d_model=d_model, + # init_div_is_residual=init_div_is_residual, + # div_is_residual=div_is_residual, + # emb_init_std=emb_init_std, + # emb_init_uniform_lim=emb_init_uniform_lim, + # ) + + # if did_init: + # break + + # if not did_init: + # for _ in module.parameters(recurse=False): + # # raise error if uninitialized module has any parameters + # raise NotImplementedError( + # f'{module.__class__.__name__} parameters are not initialized by any of the registered module_init_fns. ' + # + + # 'Please add an appropriate module_init_fn to the registry. Currently registered module_init_fns are: ' + # + ', '.join(module_init_fns.get_all())) def _normal_init_(std: float, mean: float = 0.0) -> Callable: