diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 6f7361f27f..63f76228e8 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -75,6 +75,12 @@ mpt_get_total_params, ) +# Import the fcs here so that recursive code creating the files for hf checkpoints can find them +# This is the only exception because fc.py is not imported in any other place in the codebase +# isort: off +from llmfoundry.models.layers.fc import fcs # type: ignore +# isort: on + log = logging.getLogger(__name__)