diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index d51558f04d..9d18799e93 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -40,6 +40,8 @@ except Exception as e: raise e +import logging + from transformers import PreTrainedModel, PreTrainedTokenizerBase from transformers.modeling_outputs import ( BaseModelOutputWithPast, @@ -62,31 +64,23 @@ from llmfoundry.models.layers.custom_embedding import SharedEmbedding from llmfoundry.models.layers.layer_builders import build_norm from llmfoundry.models.mpt.configuration_mpt import MPTConfig +from llmfoundry.models.utils.act_ckpt import ( + build_act_ckpt_mod_to_blocks, + check_mapping_blocks_overlap, + pass_on_block_idx, +) from llmfoundry.models.utils.config_moe_args import config_moe_args from llmfoundry.models.utils.mpt_param_count import ( mpt_get_active_params, mpt_get_total_params, ) -# NOTE: All utils are imported directly even if unused so that -# HuggingFace can detect all the needed files to copy into its modules folder. -# Otherwise, certain modules are missing. 
+# Import the fcs and param_init_fns here so that the recursive code creating the files for hf checkpoints can find them +# These are the exceptions because fc.py and param_init_fns.py are not imported in any other place in the import tree # isort: off -from llmfoundry.models.utils.meta_init_context import \ - init_empty_weights # type: ignore (see note) -from llmfoundry.models.utils.param_init_fns import ( - generic_param_init_fn_, # type: ignore (see note) -) -from llmfoundry.models.layers.ffn import resolve_ffn_act_fn # type: ignore (see note) -from llmfoundry.models.layers.fc import fcs # type: ignore (see note) - -from llmfoundry.models.utils.act_ckpt import ( - pass_on_block_idx, - build_act_ckpt_mod_to_blocks, - check_mapping_blocks_overlap, -) - -import logging +from llmfoundry.models.layers.fc import fcs # type: ignore +from llmfoundry.models.utils.param_init_fns import generic_param_init_fn_ # type: ignore +# isort: on log = logging.getLogger(__name__) diff --git a/llmfoundry/utils/huggingface_hub_utils.py b/llmfoundry/utils/huggingface_hub_utils.py index b4eec89cbd..3f7b3a0f55 100644 --- a/llmfoundry/utils/huggingface_hub_utils.py +++ b/llmfoundry/utils/huggingface_hub_utils.py @@ -3,6 +3,7 @@ import ast import importlib +import json import os from typing import Optional, Sequence @@ -139,6 +140,80 @@ def process_file( return new_files_to_process +def get_all_relative_imports(file_path: str) -> set[str]: + """Get all relative imports from a file. + + Args: + file_path (str): The file to process. + + Returns: + set[str]: The relative imports. 
+ """ + with open(file_path, 'r', encoding='utf-8') as f: + source = f.read() + + tree = ast.parse(source) + relative_imports = set() + for node in ast.walk(tree): + if isinstance( + node, + ast.ImportFrom, + ) and node.module is not None and node.level == 1: + relative_imports.add(node.module) + + return relative_imports + + +def add_relative_imports( + file_path: str, + relative_imports_to_add: set[str], +) -> None: + """Add relative imports to a file. + + Args: + file_path (str): The file to add to. + relative_imports_to_add (set[str]): The set of relative imports to add + """ + # Get the directory name where all the files are located + dir_name = os.path.dirname(file_path) + + with open(file_path, 'r', encoding='utf-8') as f: + source = f.read() + + tree = ast.parse(source) + + for relative_import in relative_imports_to_add: + import_path = os.path.join(dir_name, relative_import + '.py') + # Open up the file we are adding an import to find something to import from it + with open(import_path, 'r', encoding='utf-8') as f: + import_source = f.read() + import_tree = ast.parse(import_source) + item_to_import = None + for node in ast.walk(import_tree): + # Find the first function or class + if isinstance(node, + ast.FunctionDef) or isinstance(node, ast.ClassDef): + # Get its name to import it + item_to_import = node.name + break + + if item_to_import is None: + item_to_import = '*' + + # This will look like `from .relative_import import item_to_import` + import_node = ast.ImportFrom( + module=relative_import, + names=[ast.alias(name=item_to_import, asname=None)], + level=1, + ) + + # Insert near the top of the file, but past the from __future__ import + tree.body.insert(2, import_node) + + with open(file_path, 'w', encoding='utf-8') as f: + f.write(ast.unparse(tree)) + + def edit_files_for_hf_compatibility( folder: str, flatten_imports_prefix: Sequence[str] = ('llmfoundry',), @@ -158,9 +233,27 @@ def edit_files_for_hf_compatibility( remove_imports_prefix 
(Sequence[str], optional): Sequence of prefixes to remove. Takes precedence over flattening.
            Defaults to ('composer', 'omegaconf', 'llmfoundry.metrics', 'llmfoundry.utils.builders').
     """
+    listed_dir = os.listdir(folder)
+    # Try to acquire the config file to determine which python file is the entrypoint file
+    config_file_exists = 'config.json' in listed_dir
+    config = {}
+    if config_file_exists:
+        with open(os.path.join(folder, 'config.json'), 'r') as _f:
+            config = json.load(_f)
+    # If the config file exists, the entrypoint files would be specified in the auto map
+    entrypoint_files = set()
+    if config_file_exists:
+        for key, value in config.get('auto_map', {}).items():
+            # Only keep the modeling entrypoints, e.g. AutoModelForCausalLM
+            if 'model' not in key.lower():
+                continue
+            split_path = value.split('.')
+            if len(split_path) > 1:
+                entrypoint_files.add(split_path[0] + '.py')
+
     files_to_process = [
         os.path.join(folder, filename)
-        for filename in os.listdir(folder)
+        for filename in listed_dir
         if filename.endswith('.py')
     ]
     files_processed_and_queued = set(files_to_process)
@@ -178,3 +271,22 @@
         if file not in files_processed_and_queued:
             files_to_process.append(file)
             files_processed_and_queued.add(file)
+
+    # For each entrypoint, determine which imports are missing, and add them
+    # This is because HF does not recursively search imports when determining
+    # which files to copy into its modules cache
+    all_relative_imports = {
+        os.path.splitext(os.path.basename(f))[0]
+        for f in files_processed_and_queued
+    }
+    for entrypoint in entrypoint_files:
+        existing_relative_imports = get_all_relative_imports(
+            os.path.join(folder, entrypoint),
+        )
+        # Add in self so we don't create a circular import
+        existing_relative_imports.add(os.path.splitext(entrypoint)[0])
+        missing_relative_imports = all_relative_imports - existing_relative_imports
+        add_relative_imports(
+            os.path.join(folder, entrypoint),
+            missing_relative_imports,
+        )