Skip to content

Commit

Permalink
Make HF conversion automatically add missing imports (#1241)
Browse files Browse the repository at this point in the history
  • Loading branch information
dakinggg authored May 29, 2024
1 parent 43d149b commit b82a82b
Show file tree
Hide file tree
Showing 2 changed files with 125 additions and 19 deletions.
30 changes: 12 additions & 18 deletions llmfoundry/models/mpt/modeling_mpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@
except Exception as e:
raise e

import logging

from transformers import PreTrainedModel, PreTrainedTokenizerBase
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
Expand All @@ -62,31 +64,23 @@
from llmfoundry.models.layers.custom_embedding import SharedEmbedding
from llmfoundry.models.layers.layer_builders import build_norm
from llmfoundry.models.mpt.configuration_mpt import MPTConfig
from llmfoundry.models.utils.act_ckpt import (
build_act_ckpt_mod_to_blocks,
check_mapping_blocks_overlap,
pass_on_block_idx,
)
from llmfoundry.models.utils.config_moe_args import config_moe_args
from llmfoundry.models.utils.mpt_param_count import (
mpt_get_active_params,
mpt_get_total_params,
)

# NOTE: All utils are imported directly even if unused so that
# HuggingFace can detect all the needed files to copy into its modules folder.
# Otherwise, certain modules are missing.
# Import the fcs and param_init_fns here so that the recursive code that creates the files for HF checkpoints can find them.
# These are the exceptions because fc.py and param_init_fns.py are not imported anywhere else in the import tree.
# isort: off
from llmfoundry.models.utils.meta_init_context import \
init_empty_weights # type: ignore (see note)
from llmfoundry.models.utils.param_init_fns import (
generic_param_init_fn_, # type: ignore (see note)
)
from llmfoundry.models.layers.ffn import resolve_ffn_act_fn # type: ignore (see note)
from llmfoundry.models.layers.fc import fcs # type: ignore (see note)

from llmfoundry.models.utils.act_ckpt import (
pass_on_block_idx,
build_act_ckpt_mod_to_blocks,
check_mapping_blocks_overlap,
)

import logging
from llmfoundry.models.layers.fc import fcs # type: ignore
from llmfoundry.models.utils.param_init_fns import generic_param_init_fn_ # type: ignore
# isort: on

log = logging.getLogger(__name__)

Expand Down
114 changes: 113 additions & 1 deletion llmfoundry/utils/huggingface_hub_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import ast
import importlib
import json
import os
from typing import Optional, Sequence

Expand Down Expand Up @@ -139,6 +140,80 @@ def process_file(
return new_files_to_process


def get_all_relative_imports(file_path: str) -> set[str]:
    """Collect the modules a file pulls in through single-dot relative imports.

    Only statements of the form ``from .module import name`` are reported.
    Imports with more than one leading dot, absolute imports, and the bare
    ``from . import name`` form (which has no module part) are ignored. The
    whole syntax tree is searched, so imports nested inside functions or
    classes are included as well.

    Args:
        file_path (str): The file to process.

    Returns:
        set[str]: The module names that follow the leading dot.
    """
    with open(file_path, 'r', encoding='utf-8') as source_file:
        parsed = ast.parse(source_file.read())

    return {
        statement.module
        for statement in ast.walk(parsed)
        if isinstance(statement, ast.ImportFrom) and statement.level == 1 and
        statement.module is not None
    }


def add_relative_imports(
    file_path: str,
    relative_imports_to_add: set[str],
) -> None:
    """Add relative imports to a file.

    For every requested module a ``from .<module> import <name>`` statement is
    inserted near the top of ``file_path``. ``<name>`` is the first function or
    class found in ``<module>.py`` (looked up next to ``file_path``), or ``*``
    when the module defines neither. The statements are placed directly after
    the module docstring and any leading ``from __future__`` imports, because
    Python requires those to stay at the very beginning of the file.

    The file is rewritten with ``ast.unparse``, so comments and the original
    formatting are not preserved. When there is nothing to add, the file is
    left untouched.

    Args:
        file_path (str): The file to add to.
        relative_imports_to_add (set[str]): The set of relative imports to add

    Raises:
        FileNotFoundError: If ``file_path`` or one of the ``<module>.py`` files
            does not exist.
        SyntaxError: If one of the files cannot be parsed.
    """
    if not relative_imports_to_add:
        # Avoid a pointless round trip through ``ast.unparse``, which would
        # only strip comments from the file.
        return

    # Get the directory name where all the files are located
    dir_name = os.path.dirname(file_path)

    with open(file_path, 'r', encoding='utf-8') as f:
        tree = ast.parse(f.read())

    # Find the first position past the docstring and every leading
    # ``from __future__`` import. A fixed index would land between two
    # __future__ imports and produce a file that no longer compiles.
    insert_at = 0
    for index, statement in enumerate(tree.body):
        is_docstring = (
            index == 0 and isinstance(statement, ast.Expr) and
            isinstance(statement.value, ast.Constant) and
            isinstance(statement.value.value, str)
        )
        is_future_import = (
            isinstance(statement, ast.ImportFrom) and
            statement.level == 0 and statement.module == '__future__'
        )
        if not (is_docstring or is_future_import):
            break
        insert_at = index + 1

    new_import_nodes = []
    # Sorted so that the generated file does not depend on set iteration order
    for relative_import in sorted(relative_imports_to_add):
        import_path = os.path.join(dir_name, relative_import + '.py')
        # Open up the file we are importing from to find something to import from it
        with open(import_path, 'r', encoding='utf-8') as f:
            import_tree = ast.parse(f.read())

        # ast.walk is breadth-first, so an enclosing definition is always seen
        # before anything nested inside it. Async functions must be matched
        # too, otherwise a helper nested in an ``async def`` could be picked.
        item_to_import = next(
            (
                node.name
                for node in ast.walk(import_tree)
                if isinstance(
                    node,
                    (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef),
                )
            ),
            '*',  # Nothing to name explicitly; fall back to a star import
        )

        # This will look like `from .relative_import import item_to_import`
        new_import_nodes.append(
            ast.ImportFrom(
                module=relative_import,
                names=[ast.alias(name=item_to_import, asname=None)],
                level=1,
            ),
        )

    tree.body[insert_at:insert_at] = new_import_nodes

    with open(file_path, 'w', encoding='utf-8') as f:
        f.write(ast.unparse(tree))


def edit_files_for_hf_compatibility(
folder: str,
flatten_imports_prefix: Sequence[str] = ('llmfoundry',),
Expand All @@ -158,9 +233,27 @@ def edit_files_for_hf_compatibility(
remove_imports_prefix (Sequence[str], optional): Sequence of prefixes to remove. Takes precedence over flattening.
Defaults to ('composer', 'omegaconf', 'llmfoundry.metrics', 'llmfoundry.utils.builders').
"""
listed_dir = os.listdir(folder)

# Try to acquire the config file to determine which python file is the entrypoint file
config_file_exists = 'config.json' in listed_dir
with open(os.path.join(folder, 'config.json'), 'r') as _f:
config = json.load(_f)

# If the config file exists, the entrypoint files would be specified in the auto map
entrypoint_files = set()
if config_file_exists:
for key, value in config.get('auto_map', {}).items():
# Only keep the modeling entrypoints, e.g. AutoModelForCausalLM
if 'model' not in key.lower():
continue
split_path = value.split('.')
if len(split_path) > 1:
entrypoint_files.add(split_path[0] + '.py')

files_to_process = [
os.path.join(folder, filename)
for filename in os.listdir(folder)
for filename in listed_dir
if filename.endswith('.py')
]
files_processed_and_queued = set(files_to_process)
Expand All @@ -178,3 +271,22 @@ def edit_files_for_hf_compatibility(
if file not in files_processed_and_queued:
files_to_process.append(file)
files_processed_and_queued.add(file)

# For each entrypoint, determine which imports are missing, and add them
# This is because HF does not recursively search imports when determining
# which files to copy into its modules cache
all_relative_imports = {
os.path.splitext(os.path.basename(f))[0]
for f in files_processed_and_queued
}
for entrypoint in entrypoint_files:
existing_relative_imports = get_all_relative_imports(
os.path.join(folder, entrypoint),
)
# Add in self so we don't create a circular import
existing_relative_imports.add(os.path.splitext(entrypoint)[0])
missing_relative_imports = all_relative_imports - existing_relative_imports
add_relative_imports(
os.path.join(folder, entrypoint),
missing_relative_imports,
)

0 comments on commit b82a82b

Please sign in to comment.