diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py index c79537c781..491d510188 100644 --- a/llmfoundry/callbacks/hf_checkpointer.py +++ b/llmfoundry/callbacks/hf_checkpointer.py @@ -8,7 +8,7 @@ import os import tempfile from pathlib import Path -from typing import Optional, Union +from typing import Optional, Sequence, Union import torch from composer.core import Callback, Event, State, Time, TimeUnit @@ -32,31 +32,40 @@ class HuggingFaceCheckpointer(Callback): """Save a huggingface formatted checkpoint during training. Args: - save_folder (str): Top level folder to save checkpoints to (can be a URI). It is likely that - this would be the same as your save_folder. - save_interval: Union[str, int, Time]: The interval describing how often checkpoints should be - saved. If an integer, it will be assumed to be in :attr:`.TimeUnit.EPOCH`. - Otherwise, the unit must be either :attr:`.TimeUnit.EPOCH`, :attr:`.TimeUnit.BATCH`, + save_folder (str): Top level folder to save checkpoints to (can be a + URI). It is likely that this would be the same as your save_folder. + save_interval: Union[str, int, Time]: The interval describing how often + checkpoints should be saved. If an integer, it will be assumed to be + in :attr:`.TimeUnit.EPOCH`. Otherwise, the unit must be either + :attr:`.TimeUnit.EPOCH`, :attr:`.TimeUnit.BATCH`, :attr:`.TimeUnit.TOKEN`, or :attr:`.TimeUnit.SAMPLE`. - huggingface_folder_name (str): Folder to save each checkpoint under (can be a format string). Default is ``ba{batch}``. - precision: The precision to save the model in. Default is ``float32``. Options are ``bfloat16``, ``float16``, or ``float32``. + huggingface_folder_name (str): Folder to save each checkpoint under (can + be a format string). Default is ``ba{batch}``. + precision: The precision to save the model in. Default is ``float32``. + Options are ``bfloat16``, ``float16``, or ``float32``. overwrite (bool): Whether to overwrite previous checkpoints. - mlflow_registered_model_name (Optional[str]): The name to register the model under in the MLflow model registry. If ``None``, the model will not - be registered. Default is ``None``. - mlflow_logging_config (Optional[dict]): A dictionary of config arguments that will get passed along to the MLflow ``save_model`` call. - Expected to contain ``metadata`` and ``task`` keys. If either is unspecified, the defaults are ``'text-generation'`` and + mlflow_registered_model_name (Optional[str]): The name to register the + model under in the MLflow model registry. If ``None``, the model + will not be registered. Default is ``None``. + mlflow_logging_config (Optional[dict]): A dictionary of config arguments + that will get passed along to the MLflow ``save_model`` call. + Expected to contain ``metadata`` and ``task`` keys. If either is + unspecified, the defaults are ``'text-generation'`` and ``{'task': 'llm/v1/completions'}`` respectively. + flatten_imports (Sequence[str]): A sequence of import prefixes that will + be flattened when editing MPT files. """ def __init__( - self, - save_folder: str, - save_interval: Union[str, int, Time], - huggingface_folder_name: str = 'ba{batch}', - precision: str = 'float32', - overwrite: bool = True, - mlflow_registered_model_name: Optional[str] = None, - mlflow_logging_config: Optional[dict] = None, + self, + save_folder: str, + save_interval: Union[str, int, Time], + huggingface_folder_name: str = 'ba{batch}', + precision: str = 'float32', + overwrite: bool = True, + mlflow_registered_model_name: Optional[str] = None, + mlflow_logging_config: Optional[dict] = None, + flatten_imports: Sequence[str] = ('llmfoundry',), ): _, _, self.save_dir_format_str = parse_uri(save_folder) self.overwrite = overwrite @@ -66,6 +75,7 @@ def __init__( 'float16': torch.float16, 'bfloat16': torch.bfloat16, }[precision] + self.flatten_imports = flatten_imports # mlflow config setup self.mlflow_registered_model_name = mlflow_registered_model_name @@ -91,7 +101,7 @@ def __init__( if isinstance(save_interval, int): save_interval = Time(save_interval, TimeUnit.EPOCH) - self.save_interval = save_interval + self.save_interval: Time = save_interval self.check_interval = create_interval_scheduler( save_interval, include_end_of_training=True) self.remote_ud = maybe_create_remote_uploader_downloader_from_uri( @@ -229,7 +239,10 @@ def _save_checkpoint(self, state: State, logger: Logger): # Only need to edit files for MPT because it has custom code if original_model.config.model_type == 'mpt': log.debug('Editing MPT files for HuggingFace compatibility') - edit_files_for_hf_compatibility(temp_save_dir) + edit_files_for_hf_compatibility( + temp_save_dir, + self.flatten_imports, + ) if self.remote_ud is not None: log.info(f'Uploading HuggingFace formatted checkpoint') diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 8c134e2b9f..4c80b10ed9 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -63,7 +63,7 @@ # Otherwise, certain modules are missing. # isort: off from llmfoundry.models.utils.adapt_tokenizer import ( - AutoTokenizerForMOD, # type: ignore (see note), + AutoTokenizerForMOD, # type: ignore (see note) adapt_tokenizer_for_denoising, # type: ignore (see note) ) from llmfoundry.models.utils.hf_prefixlm_converter import ( diff --git a/llmfoundry/utils/huggingface_hub_utils.py b/llmfoundry/utils/huggingface_hub_utils.py index 47d7f79bff..a74ab1cc35 100644 --- a/llmfoundry/utils/huggingface_hub_utils.py +++ b/llmfoundry/utils/huggingface_hub_utils.py @@ -4,14 +4,14 @@ import ast import importlib import os -from typing import List, Optional +from typing import Optional, Sequence __all__ = ['edit_files_for_hf_compatibility'] class DeleteSpecificNodes(ast.NodeTransformer): - def __init__(self, nodes_to_remove: List[ast.AST]): + def __init__(self, nodes_to_remove: list[ast.AST]): self.nodes_to_remove = nodes_to_remove def visit(self, node: ast.AST) -> Optional[ast.AST]: @@ -39,7 +39,26 @@ def find_module_file(module_name: str) -> str: return module_file -def process_file(file_path: str, folder_path: str) -> List[str]: +def _flatten_import( + node: ast.ImportFrom, + flatten_imports_prefix: Sequence[str], +) -> bool: + """Returns True if import should be flattened. + + Checks whether the node starts the same as any of the imports in + flatten_imports_prefix. + """ + for import_prefix in flatten_imports_prefix: + if node.module is not None and node.module.startswith(import_prefix): + return True + return False + + +def process_file( + file_path: str, + folder_path: str, + flatten_imports_prefix: Sequence[str], +) -> list[str]: with open(file_path, 'r') as f: source = f.read() @@ -51,37 +70,35 @@ def process_file(file_path: str, folder_path: str) -> List[str]: new_files_to_process = [] nodes_to_remove = [] for node in ast.walk(tree): - # convert any llmfoundry imports into relative imports - if isinstance( - node, ast.ImportFrom - ) and node.module is not None and node.module.startswith('llmfoundry'): + # Convert any llmfoundry imports into relative imports + if (isinstance(node, ast.ImportFrom) and node.module is not None and + _flatten_import(node, flatten_imports_prefix)): module_path = find_module_file(node.module) node.module = convert_to_relative_import(node.module, parent_module_name) - # recursively process any llmfoundry files + # Recursively process any llmfoundry files new_files_to_process.append(module_path) - # remove any imports from composer or omegaconf + # Remove any imports from composer or omegaconf elif isinstance(node, ast.ImportFrom) and node.module is not None and ( node.module.startswith('composer') or node.module.startswith('omegaconf')): nodes_to_remove.append(node) - # remove the Composer* class - elif isinstance(node, - ast.ClassDef) and node.name.startswith('Composer'): + # Remove the Composer* class + elif (isinstance(node, ast.ClassDef) and + node.name.startswith('Composer')): nodes_to_remove.append(node) - # remove the __all__ declaration in any __init__.py files, whose enclosing module - # will be converted to a single file of the same name - elif isinstance(node, - ast.Assign) and len(node.targets) == 1 and isinstance( - node.targets[0], - ast.Name) and node.targets[0].id == '__all__': + # Remove the __all__ declaration in any __init__.py files, whose + # enclosing module will be converted to a single file of the same name + elif (isinstance(node, ast.Assign) and len(node.targets) == 1 and + isinstance(node.targets[0], ast.Name) and + node.targets[0].id == '__all__'): nodes_to_remove.append(node) transformer = DeleteSpecificNodes(nodes_to_remove) new_tree = transformer.visit(tree) new_filename = os.path.basename(file_path) - # special case for __init__.py to mimic the original submodule + # Special case for __init__.py to mimic the original submodule if new_filename == '__init__.py': new_filename = file_path.split('/')[-2] + '.py' new_file_path = os.path.join(folder_path, new_filename) @@ -92,7 +109,10 @@ def process_file(file_path: str, folder_path: str) -> List[str]: return new_files_to_process -def edit_files_for_hf_compatibility(folder: str) -> None: +def edit_files_for_hf_compatibility( + folder: str, + flatten_imports_prefix: Sequence[str] = ('llmfoundry',), +) -> None: files_to_process = [ os.path.join(folder, filename) for filename in os.listdir(folder) @@ -103,7 +123,7 @@ def edit_files_for_hf_compatibility(folder: str) -> None: while len(files_to_process) > 0: to_process = files_to_process.pop() if os.path.isfile(to_process) and to_process.endswith('.py'): - to_add = process_file(to_process, folder) + to_add = process_file(to_process, folder, flatten_imports_prefix) for file in to_add: if file not in files_processed_and_queued: files_to_process.append(file) diff --git a/tests/utils/test_huggingface_hub_utils.py b/tests/utils/test_huggingface_hub_utils.py new file mode 100644 index 0000000000..5effb3a771 --- /dev/null +++ b/tests/utils/test_huggingface_hub_utils.py @@ -0,0 +1,16 @@ +# Copyright 2022 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +import ast + +from llmfoundry.utils.huggingface_hub_utils import _flatten_import + + +def test_flatten_import_true(): + node = ast.ImportFrom('y', ['x', 'y', 'z']) + assert _flatten_import(node, ('x', 'y', 'z')) + + +def test_flatten_import_false(): + node = ast.ImportFrom('y', ['x', 'y', 'z']) + assert not _flatten_import(node, ('x', 'z'))