diff --git a/llmfoundry/models/mpt/configuration_mpt.py b/llmfoundry/models/mpt/configuration_mpt.py index 2fb2bdf662..35db97ca3f 100644 --- a/llmfoundry/models/mpt/configuration_mpt.py +++ b/llmfoundry/models/mpt/configuration_mpt.py @@ -20,6 +20,7 @@ from llmfoundry.models.layers.norm import LPLayerNorm # type: ignore (see note) from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY # type: ignore (see note) from llmfoundry.models.layers.layer_builders import build_norm # type: ignore (see note) +from llmfoundry.layers_registry import norms # type: ignore (see note) ffn_config_defaults: Dict = { 'ffn_type': 'mptmlp',