Skip to content

Commit

Permalink
put it back
Browse files Browse the repository at this point in the history
  • Loading branch information
dakinggg committed Apr 5, 2024
1 parent 406df70 commit 0d612ca
Showing 1 changed file with 26 additions and 26 deletions.
52 changes: 26 additions & 26 deletions llmfoundry/models/utils/param_init_fns.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,32 +255,32 @@ def generic_param_init_fn_(
f'Expected init_div_is_residual to be boolean or numeric, got {init_div_is_residual}'
)

# all_module_init_fns = [
# module_init_fns.get(name) for name in module_init_fns.get_all()
# ]
# did_init = False
# for module_init_fn in all_module_init_fns:
# did_init = module_init_fn(
# module=module,
# init_fn_=init_fn_,
# d_model=d_model,
# init_div_is_residual=init_div_is_residual,
# div_is_residual=div_is_residual,
# emb_init_std=emb_init_std,
# emb_init_uniform_lim=emb_init_uniform_lim,
# )

# if did_init:
# break

# if not did_init:
# for _ in module.parameters(recurse=False):
# # raise error if uninitialized module has any parameters
# raise NotImplementedError(
# f'{module.__class__.__name__} parameters are not initialized by any of the registered module_init_fns. '
# +
# 'Please add an appropriate module_init_fn to the registry. Currently registered module_init_fns are: '
# + ', '.join(module_init_fns.get_all()))
all_module_init_fns = [
module_init_fns.get(name) for name in module_init_fns.get_all()
]
did_init = False
for module_init_fn in all_module_init_fns:
did_init = module_init_fn(
module=module,
init_fn_=init_fn_,
d_model=d_model,
init_div_is_residual=init_div_is_residual,
div_is_residual=div_is_residual,
emb_init_std=emb_init_std,
emb_init_uniform_lim=emb_init_uniform_lim,
)

if did_init:
break

if not did_init:
for _ in module.parameters(recurse=False):
# raise error if uninitialized module has any parameters
raise NotImplementedError(
f'{module.__class__.__name__} parameters are not initialized by any of the registered module_init_fns. '
+
'Please add an appropriate module_init_fn to the registry. Currently registered module_init_fns are: '
+ ', '.join(module_init_fns.get_all()))


def _normal_init_(std: float, mean: float = 0.0) -> Callable:
Expand Down

0 comments on commit 0d612ca

Please sign in to comment.