From 1881f2fd3f599ad539affcd20b803dc85be1137d Mon Sep 17 00:00:00 2001 From: Daniel King Date: Sun, 26 May 2024 23:59:58 -0700 Subject: [PATCH 1/5] remove the dummy imports --- llmfoundry/models/mpt/modeling_mpt.py | 27 ++---- llmfoundry/utils/huggingface_hub_utils.py | 112 +++++++++++++++++++++- 2 files changed, 118 insertions(+), 21 deletions(-) diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index d51558f04d..6f7361f27f 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -40,6 +40,8 @@ except Exception as e: raise e +import logging + from transformers import PreTrainedModel, PreTrainedTokenizerBase from transformers.modeling_outputs import ( BaseModelOutputWithPast, @@ -62,32 +64,17 @@ from llmfoundry.models.layers.custom_embedding import SharedEmbedding from llmfoundry.models.layers.layer_builders import build_norm from llmfoundry.models.mpt.configuration_mpt import MPTConfig +from llmfoundry.models.utils.act_ckpt import ( + build_act_ckpt_mod_to_blocks, + check_mapping_blocks_overlap, + pass_on_block_idx, +) from llmfoundry.models.utils.config_moe_args import config_moe_args from llmfoundry.models.utils.mpt_param_count import ( mpt_get_active_params, mpt_get_total_params, ) -# NOTE: All utils are imported directly even if unused so that -# HuggingFace can detect all the needed files to copy into its modules folder. -# Otherwise, certain modules are missing. -# isort: off -from llmfoundry.models.utils.meta_init_context import \ - init_empty_weights # type: ignore (see note) -from llmfoundry.models.utils.param_init_fns import ( - generic_param_init_fn_, # type: ignore (see note) -) -from llmfoundry.models.layers.ffn import resolve_ffn_act_fn # type: ignore (see note) -from llmfoundry.models.layers.fc import fcs # type: ignore (see note) - -from llmfoundry.models.utils.act_ckpt import ( - pass_on_block_idx, - build_act_ckpt_mod_to_blocks, - check_mapping_blocks_overlap, -) - -import logging - log = logging.getLogger(__name__) diff --git a/llmfoundry/utils/huggingface_hub_utils.py b/llmfoundry/utils/huggingface_hub_utils.py index b4eec89cbd..ea98be2ced 100644 --- a/llmfoundry/utils/huggingface_hub_utils.py +++ b/llmfoundry/utils/huggingface_hub_utils.py @@ -3,6 +3,7 @@ import ast import importlib +import json import os from typing import Optional, Sequence @@ -139,6 +140,80 @@ def process_file( return new_files_to_process +def get_all_relative_imports(file_path: str) -> set[str]: + """Get all relative imports from a file. + + Args: + file_path (str): The file to process. + + Returns: + set[str]: The relative imports. + """ + with open(file_path, 'r', encoding='utf-8') as f: + source = f.read() + + tree = ast.parse(source) + relative_imports = set() + for node in ast.walk(tree): + if isinstance( + node, + ast.ImportFrom, + ) and node.module is not None and node.level == 1: + relative_imports.add(node.module) + + return relative_imports + + +def add_relative_imports( + file_path: str, + relative_imports_to_add: set[str], +) -> None: + """Add relative imports to a file. + + Args: + file_path (str): The file to add to. 
+        relative_imports_to_add (set[str]): The set of relative imports to add
+    """
+    # Get the directory name where all the files are located
+    dir_name = os.path.dirname(file_path)
+
+    with open(file_path, 'r', encoding='utf-8') as f:
+        source = f.read()
+
+    tree = ast.parse(source)
+
+    for relative_import in relative_imports_to_add:
+        import_path = os.path.join(dir_name, relative_import + '.py')
+        # Open the file being imported to find something in it to import
+        with open(import_path, 'r', encoding='utf-8') as f:
+            import_source = f.read()
+        import_tree = ast.parse(import_source)
+        item_to_import = None
+        for node in ast.walk(import_tree):
+            # Find the first function or class
+            if isinstance(node,
+                          ast.FunctionDef) or isinstance(node, ast.ClassDef):
+                # Get its name to import it
+                item_to_import = node.name
+                break
+
+        if item_to_import is None:
+            item_to_import = '*'
+
+        # This will look like `from .relative_import import item_to_import`
+        import_node = ast.ImportFrom(
+            module=relative_import,
+            names=[ast.alias(name=item_to_import, asname=None)],
+            level=1,
+        )
+
+        # Insert near the top of the file, but past the from __future__ import
+        tree.body.insert(2, import_node)
+
+    with open(file_path, 'w', encoding='utf-8') as f:
+        f.write(ast.unparse(tree))
+
+
 def edit_files_for_hf_compatibility(
     folder: str,
     flatten_imports_prefix: Sequence[str] = ('llmfoundry',),
@@ -158,9 +233,27 @@ def edit_files_for_hf_compatibility(
         remove_imports_prefix (Sequence[str], optional): Sequence of prefixes to remove. Takes precedence over flattening. Defaults to ('composer', 'omegaconf', 'llmfoundry.metrics', 'llmfoundry.utils.builders').
     """
+    listed_dir = os.listdir(folder)
+
+    # Try to acquire the config file to determine which python file is the entrypoint file
+    config_file_exists = 'config.json' in listed_dir
+    with open(os.path.join(folder, 'config.json'), 'r') as _f:
+        config = json.load(_f)
+
+    # If the config file exists, the entrypoint files are specified in the auto map
+    entrypoint_files = set()
+    if config_file_exists:
+        for key, value in config['auto_map'].items():
+            # Only keep the modeling entrypoints, e.g. AutoModelForCausalLM
+            if 'model' not in key.lower():
+                continue
+            split_path = value.split('.')
+            if len(split_path) > 1:
+                entrypoint_files.add(split_path[0] + '.py')
+
     files_to_process = [
         os.path.join(folder, filename)
-        for filename in os.listdir(folder)
+        for filename in listed_dir
         if filename.endswith('.py')
     ]
     files_processed_and_queued = set(files_to_process)
@@ -178,3 +271,20 @@
             if file not in files_processed_and_queued:
                 files_to_process.append(file)
                 files_processed_and_queued.add(file)
+
+    # For each entrypoint, determine which imports are missing, and add them
+    all_relative_imports = {
+        os.path.splitext(os.path.basename(f))[0]
+        for f in files_processed_and_queued
+    }
+    for entrypoint in entrypoint_files:
+        existing_relative_imports = get_all_relative_imports(
+            os.path.join(folder, entrypoint),
+        )
+        # Add in self so we don't create a circular import
+        existing_relative_imports.add(os.path.splitext(entrypoint)[0])
+        missing_relative_imports = all_relative_imports - existing_relative_imports
+        add_relative_imports(
+            os.path.join(folder, entrypoint),
+            missing_relative_imports,
+        )
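
To make the AST machinery in PATCH 1/5 concrete, the following is a minimal
sketch of the transformation add_relative_imports performs on an entrypoint
file. The module and symbol names here (blocks, MPTBlock, fc, fcs) are
illustrative placeholders rather than values taken from a real checkpoint;
only the ast.ImportFrom construction and the insert call mirror the patch:

    import ast

    # A toy entrypoint that already has one relative import.
    source = 'from .blocks import MPTBlock\n'
    tree = ast.parse(source)

    # Build `from .fc import fcs` the same way add_relative_imports does;
    # level=1 unparses as a single leading dot, i.e. a relative import.
    import_node = ast.ImportFrom(
        module='fc',
        names=[ast.alias(name='fcs', asname=None)],
        level=1,
    )

    # Index 2 skips a module docstring and a `from __future__` import when
    # present; list.insert clamps to the end for shorter module bodies.
    tree.body.insert(2, import_node)

    print(ast.unparse(tree))
    # from .blocks import MPTBlock
    # from .fc import fcs

Importing the first function or class found in the target file (falling back
to *) only needs to make the dependency visible to HuggingFace; the imported
name itself is never used.
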
From d80087f131d349d14d6d589e4dd8f6a2455d6204 Mon Sep 17 00:00:00 2001
From: Daniel King
Date: Mon, 27 May 2024 00:02:17 -0700
Subject: [PATCH 2/5] more comment

---
 llmfoundry/utils/huggingface_hub_utils.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/llmfoundry/utils/huggingface_hub_utils.py b/llmfoundry/utils/huggingface_hub_utils.py
index ea98be2ced..5f0eaf0e79 100644
--- a/llmfoundry/utils/huggingface_hub_utils.py
+++ b/llmfoundry/utils/huggingface_hub_utils.py
@@ -273,6 +273,8 @@ def edit_files_for_hf_compatibility(
                 files_processed_and_queued.add(file)
 
     # For each entrypoint, determine which imports are missing, and add them
+    # This is because HF does not recursively search imports when determining
+    # which files to copy into its modules cache
     all_relative_imports = {
         os.path.splitext(os.path.basename(f))[0]
         for f in files_processed_and_queued

From 1507c8b3503458c0842e1880bede562f7d9bfc67 Mon Sep 17 00:00:00 2001
From: Daniel King
Date: Mon, 27 May 2024 00:32:02 -0700
Subject: [PATCH 3/5] add fcs back

---
 llmfoundry/models/mpt/modeling_mpt.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py
index 6f7361f27f..63f76228e8 100644
--- a/llmfoundry/models/mpt/modeling_mpt.py
+++ b/llmfoundry/models/mpt/modeling_mpt.py
@@ -75,6 +75,12 @@
     mpt_get_total_params,
 )
 
+# Import the fcs here so that recursive code creating the files for hf checkpoints can find them
+# This is the only exception because fc.py is not imported in any other place in the codebase
+# isort: off
+from llmfoundry.models.layers.fc import fcs  # type: ignore
+# isort: on
+
 log = logging.getLogger(__name__)

From 76921d8a7097cef19dea85d560f94f12c0420d4a Mon Sep 17 00:00:00 2001
From: Daniel King
Date: Mon, 27 May 2024 00:49:55 -0700
Subject: [PATCH 4/5] add param init fns back

---
 llmfoundry/models/mpt/modeling_mpt.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py
index 63f76228e8..9d18799e93 100644
--- a/llmfoundry/models/mpt/modeling_mpt.py
+++ b/llmfoundry/models/mpt/modeling_mpt.py
@@ -75,10 +75,11 @@
     mpt_get_total_params,
 )
 
-# Import the fcs here so that recursive code creating the files for hf checkpoints can find them
-# This is the only exception because fc.py is not imported in any other place in the codebase
+# Import the fcs and param_init_fns here so that the recursive code creating the files for hf checkpoints can find them
+# These are the exceptions because fc.py and param_init_fns.py are not imported in any other place in the import tree
 # isort: off
 from llmfoundry.models.layers.fc import fcs  # type: ignore
+from llmfoundry.models.utils.param_init_fns import generic_param_init_fn_  # type: ignore
 # isort: on
 
 log = logging.getLogger(__name__)
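
PATCH 3/5 and PATCH 4/5 re-add imports that are intentionally dead code. Per
the comment added in PATCH 2/5, HuggingFace does not recursively search
imports when deciding which files to copy into its modules cache, so a file
that is only reachable transitively never gets copied next to the converted
modeling file. A rough sketch of that single-level discovery step, assuming a
regex-based scan similar in spirit to transformers' dynamic_module_utils
(this is a simplification, not the actual Transformers implementation; in the
exported checkpoint the llmfoundry imports have already been flattened into
single-file relative imports like the ones below):

    import re

    def find_relative_imports(source: str) -> set[str]:
        # Collect X from statements like `from .X import name`; a real
        # scanner also handles other relative-import spellings.
        return set(
            re.findall(
                r'^\s*from\s+\.(\w+)\s+import\s+',
                source,
                flags=re.MULTILINE,
            ),
        )

    entrypoint_source = (
        'from .fc import fcs  # type: ignore\n'
        'from .param_init_fns import generic_param_init_fn_  # type: ignore\n'
    )
    print(find_relative_imports(entrypoint_source))
    # {'fc', 'param_init_fns'} (set ordering may vary)

Without the seemingly unused fcs and generic_param_init_fn_ imports, fc.py
and param_init_fns.py would be invisible to that scan.
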
From d79ba7c03bcdb70ebb7bfa1a264596a9fc9c0cea Mon Sep 17 00:00:00 2001
From: Daniel King
Date: Mon, 27 May 2024 00:54:57 -0700
Subject: [PATCH 5/5] more resilient

---
 llmfoundry/utils/huggingface_hub_utils.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/llmfoundry/utils/huggingface_hub_utils.py b/llmfoundry/utils/huggingface_hub_utils.py
index 5f0eaf0e79..3f7b3a0f55 100644
--- a/llmfoundry/utils/huggingface_hub_utils.py
+++ b/llmfoundry/utils/huggingface_hub_utils.py
@@ -240,10 +240,10 @@ def edit_files_for_hf_compatibility(
     with open(os.path.join(folder, 'config.json'), 'r') as _f:
         config = json.load(_f)
 
-    # If the config file exists, the entrypoint files are specified in the auto map
+    # If the config file exists, the entrypoint files would be specified in the auto map
     entrypoint_files = set()
     if config_file_exists:
-        for key, value in config['auto_map'].items():
+        for key, value in config.get('auto_map', {}).items():
             # Only keep the modeling entrypoints, e.g. AutoModelForCausalLM
             if 'model' not in key.lower():
                 continue
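
For context on the guard added in PATCH 5/5: the auto_map block of a
trust_remote_code config.json maps Auto classes to `module.ClassName` entries,
along these lines (illustrative, modeled on the MPT checkpoints on the Hub):

    {
        "auto_map": {
            "AutoConfig": "configuration_mpt.MPTConfig",
            "AutoModelForCausalLM": "modeling_mpt.MPTForCausalLM"
        }
    }

A minimal sketch of the entrypoint extraction with the `.get` guard applied.
The function name here is hypothetical, and unlike the patched function it
also skips reading config.json entirely when the file is absent (the patched
code still calls json.load before checking config_file_exists):

    import json
    import os

    def get_entrypoint_files(folder: str) -> set[str]:
        entrypoints: set[str] = set()
        config_path = os.path.join(folder, 'config.json')
        if not os.path.exists(config_path):
            return entrypoints
        with open(config_path, 'r') as f:
            config = json.load(f)
        # `.get` mirrors PATCH 5/5: no auto_map means no entrypoints
        for key, value in config.get('auto_map', {}).items():
            if 'model' not in key.lower():
                continue  # skip AutoConfig and other non-modeling entries
            split_path = value.split('.')
            if len(split_path) > 1:
                entrypoints.add(split_path[0] + '.py')
        return entrypoints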