diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 63f76228e8..9d18799e93 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -75,10 +75,11 @@ mpt_get_total_params, ) -# Import the fcs here so that recursive code creating the files for hf checkpoints can find them -# This is the only exception because fc.py is not imported in any other place in the codebase +# Import the fcs and param_init_fns here so that the recursive code creating the files for HF checkpoints can find them +# They are imported explicitly here because fc.py and param_init_fns.py are not imported anywhere else in the import tree # isort: off from llmfoundry.models.layers.fc import fcs # type: ignore +from llmfoundry.models.utils.param_init_fns import generic_param_init_fn_ # type: ignore # isort: on log = logging.getLogger(__name__)