Skip to content

Commit

Permalink
Merge branch 'main' into openai_compatible_gauntlet
Browse files Browse the repository at this point in the history
  • Loading branch information
maxisawesome committed Apr 13, 2024
2 parents f493e35 + f01f625 commit e62f584
Show file tree
Hide file tree
Showing 8 changed files with 196 additions and 67 deletions.
3 changes: 0 additions & 3 deletions llmfoundry/eval/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
# Copyright 2024 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

"""Natively supported datasets."""

from llmfoundry.eval.datasets import (
Expand Down
22 changes: 22 additions & 0 deletions llmfoundry/layers_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,30 @@
entry_points=True,
description=_attention_implementations_description)

_param_init_fns_description = (
'The param_init_fns registry is used to register functions that initialize parameters.'
+ 'These will be called on a module to initialize its parameters. See param_init_fns.py for examples.'
)
param_init_fns = create_registry('llmfoundry',
'param_init_fns',
generic_type=Callable[..., None],
entry_points=True,
description=_param_init_fns_description)

_module_init_fns_description = """The module_init_fns registry is used to register functions that initialize specific modules.
These functions should return True if they initialize the module, and False otherwise. This allows them to be called without knowing their contents.
They should take in the module, init_div_is_residual, and div_is_residual arguments."""
module_init_fns = create_registry('llmfoundry',
'module_init_fns',
generic_type=Callable[..., bool],
entry_points=True,
description=_module_init_fns_description)

__all__ = [
'norms',
'param_init_fns',
'module_init_fns',
'ffns',
'ffns_with_norm',
'ffns_with_megablocks',
Expand Down
7 changes: 3 additions & 4 deletions llmfoundry/models/mpt/modeling_mpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
from transformers.models.llama.modeling_llama import \
LlamaRotaryEmbedding as HFRotaryEmbedding

from llmfoundry.layers_registry import norms
from llmfoundry.layers_registry import norms, param_init_fns
from llmfoundry.models.layers.attention import (attn_bias_shape,
build_attn_bias, gen_slopes)
from llmfoundry.models.layers.blocks import MPTBlock
Expand All @@ -62,7 +62,6 @@
init_empty_weights # type: ignore (see note)
from llmfoundry.models.utils.param_init_fns import (
generic_param_init_fn_, # type: ignore (see note)
MODEL_INIT_REGISTRY,
)
from llmfoundry.models.layers.ffn import resolve_ffn_act_fn # type: ignore (see note)

Expand Down Expand Up @@ -678,7 +677,7 @@ def forward(
# Param Initialization, needed for device='meta' fast initialization
def param_init_fn(self, module: nn.Module) -> None:
init_fn_name = self.config.init_config['name']
MODEL_INIT_REGISTRY[init_fn_name](
param_init_fns.get(init_fn_name)(
module=module,
n_layers=self.config.n_layers,
d_model=self.config.d_model,
Expand Down Expand Up @@ -838,7 +837,7 @@ def forward(
# Param Initialization, needed for device='meta' fast initialization
def param_init_fn(self, module: nn.Module) -> None:
init_fn_name = self.config.init_config['name']
MODEL_INIT_REGISTRY[init_fn_name](
param_init_fns.get(init_fn_name)(
module=module,
n_layers=self.config.n_layers,
d_model=self.config.d_model,
Expand Down
4 changes: 1 addition & 3 deletions llmfoundry/models/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,12 @@
init_on_device)
from llmfoundry.models.utils.mpt_param_count import (mpt_get_active_params,
mpt_get_total_params)
from llmfoundry.models.utils.param_init_fns import (MODEL_INIT_REGISTRY,
generic_param_init_fn_)
from llmfoundry.models.utils.param_init_fns import generic_param_init_fn_

__all__ = [
'init_empty_weights',
'init_on_device',
'generic_param_init_fn_',
'MODEL_INIT_REGISTRY',
'config_moe_args',
'mpt_get_active_params',
'mpt_get_total_params',
Expand Down
Loading

0 comments on commit e62f584

Please sign in to comment.