diff --git a/llmfoundry/models/utils/param_init_fns.py b/llmfoundry/models/utils/param_init_fns.py index b05a537f60..88618be7c9 100644 --- a/llmfoundry/models/utils/param_init_fns.py +++ b/llmfoundry/models/utils/param_init_fns.py @@ -255,32 +255,32 @@ def generic_param_init_fn_( f'Expected init_div_is_residual to be boolean or numeric, got {init_div_is_residual}' ) - # all_module_init_fns = [ - # module_init_fns.get(name) for name in module_init_fns.get_all() - # ] - # did_init = False - # for module_init_fn in all_module_init_fns: - # did_init = module_init_fn( - # module=module, - # init_fn_=init_fn_, - # d_model=d_model, - # init_div_is_residual=init_div_is_residual, - # div_is_residual=div_is_residual, - # emb_init_std=emb_init_std, - # emb_init_uniform_lim=emb_init_uniform_lim, - # ) - - # if did_init: - # break - - # if not did_init: - # for _ in module.parameters(recurse=False): - # # raise error if uninitialized module has any parameters - # raise NotImplementedError( - # f'{module.__class__.__name__} parameters are not initialized by any of the registered module_init_fns. ' - # + - # 'Please add an appropriate module_init_fn to the registry. Currently registered module_init_fns are: ' - # + ', '.join(module_init_fns.get_all())) + all_module_init_fns = [ + module_init_fns.get(name) for name in module_init_fns.get_all() + ] + did_init = False + for module_init_fn in all_module_init_fns: + did_init = module_init_fn( + module=module, + init_fn_=init_fn_, + d_model=d_model, + init_div_is_residual=init_div_is_residual, + div_is_residual=div_is_residual, + emb_init_std=emb_init_std, + emb_init_uniform_lim=emb_init_uniform_lim, + ) + + if did_init: + break + + if not did_init: + for _ in module.parameters(recurse=False): + # raise error if uninitialized module has any parameters + raise NotImplementedError( + f'{module.__class__.__name__} parameters are not initialized by any of the registered module_init_fns. ' + + + 'Please add an appropriate module_init_fn to the registry. Currently registered module_init_fns are: ' + + ', '.join(module_init_fns.get_all())) def _normal_init_(std: float, mean: float = 0.0) -> Callable: