From 23c31734b3c84d104e7dbb3643b77cdc8a8b88ec Mon Sep 17 00:00:00 2001
From: Irene Dea
Date: Wed, 3 Apr 2024 16:39:37 -0700
Subject: [PATCH] Background mlflow model registration (#1078)

---
 llmfoundry/callbacks/hf_checkpointer.py      | 370 +++++++++++-------
 .../inference/test_convert_composer_to_hf.py |  26 +-
 2 files changed, 245 insertions(+), 151 deletions(-)

diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py
index 9eb23f1030..0d57feaef4 100644
--- a/llmfoundry/callbacks/hf_checkpointer.py
+++ b/llmfoundry/callbacks/hf_checkpointer.py
@@ -7,9 +7,12 @@
 import math
 import os
 import re
+import shutil
 import tempfile
+import time
+from multiprocessing.context import SpawnProcess
 from pathlib import Path
-from typing import Any, Dict, Optional, Sequence, Union
+from typing import Any, Dict, List, Optional, Sequence, Union
 
 import torch
 from composer.core import Callback, Event, State, Time, TimeUnit
@@ -72,6 +75,31 @@ def _maybe_get_license_filename(
 
     return None
 
 
+def _register_model_with_run_id_multiprocess(
+    mlflow_logger: MLFlowLogger,
+    composer_logging_level: int,
+    model_uri: str,
+    name: str,
+    await_creation_for: int,
+):
+    """Call MLFlowLogger.register_model_with_run_id.
+
+    Used mainly to register from a child process.
+    """
+    # Setup logging for child process. This ensures that any logs from composer are surfaced.
+    if composer_logging_level > 0:
+        # If logging_level is 0, then the composer logger was unset.
+        logging.basicConfig(
+            format=
+            f'%(asctime)s: rank{dist.get_global_rank()}[%(process)d][%(threadName)s]: %(levelname)s: %(name)s: %(message)s'
+        )
+        logging.getLogger('composer').setLevel(composer_logging_level)
+
+    # Register model.
+    mlflow_logger.register_model_with_run_id(
+        model_uri=model_uri, name=name, await_creation_for=await_creation_for)
+
+
 class HuggingFaceCheckpointer(Callback):
     """Save a huggingface formatted checkpoint during training.
@@ -171,6 +199,10 @@ def __init__(
         self.last_checkpoint_batch: Optional[Time] = None
         self.mlflow_loggers = []
 
+        self.child_processes: List[SpawnProcess] = []
+        # Temporary save directory used by child_processes.
+        self.temp_save_dir = None
+
     def run_event(self, event: Event, state: State, logger: Logger) -> None:
         # The interval scheduler handles only returning True for the appropriate events
         if state.get_elapsed_duration() is not None and self.check_interval(
@@ -202,6 +234,21 @@ def run_event(self, event: Event, state: State, logger: Logger) -> None:
                 import mlflow
                 mlflow.environment_variables.MLFLOW_HUGGINGFACE_MODEL_MAX_SHARD_SIZE.set(
                     '5GB')
+        elif event == Event.FIT_END:
+            # Wait for all child processes spawned by the callback to finish.
+            timeout = 3600
+            wait_start = time.time()
+            while not self._all_child_processes_done():
+                wait_time = time.time() - wait_start
+                if wait_time > timeout:
+                    raise TimeoutError(
+                        f'Waited {wait_time} seconds for child processes to complete. Exceeded timeout of {timeout} seconds.'
+                    )
+                time.sleep(2)
+
+            # Clean up temporary save directory; all processes are done with it.
+            if self.temp_save_dir is not None:
+                shutil.rmtree(self.temp_save_dir)
 
     def _is_last_batch(self, state: State):
         elapsed_duration = state.get_elapsed_duration()
@@ -218,6 +265,12 @@ def _is_last_batch(self, state: State):
 
         return False
 
+    def _all_child_processes_done(self) -> bool:
+        not_done = any(process.is_alive() for process in self.child_processes)
+        x = torch.tensor(1 if not_done else 0).to(device='cuda')
+        dist.all_reduce(x, reduce_operation='MAX')
+        return x.item() == 0
+
     def _save_checkpoint(self, state: State, logger: Logger):
         del logger  # unused
 
@@ -235,158 +288,175 @@ def _save_checkpoint(self, state: State, logger: Logger):
                 Path(self.save_dir_format_str) /
                 self.huggingface_folder_name_fstr), state.run_name,
             state.timestamp)
-        dir_context_mgr = tempfile.TemporaryDirectory(
-        ) if self.remote_ud is not None else contextlib.nullcontext(
-            enter_result=save_dir)
-
-        with dir_context_mgr as temp_save_dir:
-            assert isinstance(temp_save_dir,
-                              str)  # pyright doesn't know about enter_result
-
-            log.debug('Gathering state dict')
-            from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
-
-            if state.is_model_ddp:
-                composer_model = state.model.module
-                original_model: PreTrainedModel = state.model.module.model
-                state_dict_model = state.model.module.model
-                original_tokenizer = state.model.module.tokenizer
-            elif isinstance(state.model.model, FSDP):
-                composer_model = state.model
-                original_model: PreTrainedModel = state.model.model.module
-                state_dict_model = state.model.model
-                original_tokenizer = state.model.tokenizer
+
+        # Use a temporary directory if save_dir is remote.
+        use_temp_dir = self.remote_ud is not None
+        temp_save_dir = tempfile.mkdtemp() if use_temp_dir else save_dir
+
+        log.debug('Gathering state dict')
+        from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+
+        if state.is_model_ddp:
+            composer_model = state.model.module
+            original_model: PreTrainedModel = state.model.module.model
+            state_dict_model = state.model.module.model
+            original_tokenizer = state.model.module.tokenizer
+        elif isinstance(state.model.model, FSDP):
+            composer_model = state.model
+            original_model: PreTrainedModel = state.model.model.module
+            state_dict_model = state.model.model
+            original_tokenizer = state.model.tokenizer
+        else:
+            composer_model = state.model
+            original_model: PreTrainedModel = state.model.model
+            state_dict_model = state.model.model
+            original_tokenizer = state.model.tokenizer
+
+        state_dict_context = fsdp_state_dict_type_context(
+            original_model,
+            state_dict_type='full') if ((not state.is_model_ddp) and isinstance(
+                state_dict_model, FSDP)) else contextlib.nullcontext()
+
+        with state_dict_context:
+            state_dict = state_dict_model.state_dict()
+
+            # convert the state dict to the requested precision
+            for k, v in state_dict.items():
+                if isinstance(v, torch.Tensor):
+                    state_dict[k] = v.to(dtype=self.dtype)
+
+        new_model_instance = None  # Need this for pyright because variable could be unbound
+
+        if dist.get_global_rank() == 0:
+            log.debug('Saving Hugging Face checkpoint in global rank 0')
+
+            copied_config = copy.deepcopy(original_model.config)
+            if copied_config.model_type == 'mpt':
+                copied_config.attn_config['attn_impl'] = 'torch'
+                copied_config.init_device = 'cpu'
+
+            log.debug(f'Creating new model instance')
+
+            if composer_model.using_peft:
+                # We don't use meta here because the state dict does not contain the full
+                # model, only the adapter weights.
+                active_adapter = original_model.active_adapter
+                base_model = original_model.get_base_model()
+                new_base_model_instance = type(base_model)(copied_config)
+
+                new_model_instance = type(original_model)(
+                    new_base_model_instance,
+                    original_model.peft_config[active_adapter])
+                new_model_instance.to(dtype=self.dtype)
             else:
-                composer_model = state.model
-                original_model: PreTrainedModel = state.model.model
-                state_dict_model = state.model.model
-                original_tokenizer = state.model.tokenizer
-
-            state_dict_context = fsdp_state_dict_type_context(
-                original_model, state_dict_type='full') if (
-                    (not state.is_model_ddp) and isinstance(
-                        state_dict_model, FSDP)) else contextlib.nullcontext()
-
-            with state_dict_context:
-                state_dict = state_dict_model.state_dict()
-
-                # convert the state dict to the requested precision
-                for k, v in state_dict.items():
-                    if isinstance(v, torch.Tensor):
-                        state_dict[k] = v.to(dtype=self.dtype)
-
-            if dist.get_global_rank() == 0:
-                log.debug('Saving Hugging Face checkpoint in global rank 0')
-
-                copied_config = copy.deepcopy(original_model.config)
-                if copied_config.model_type == 'mpt':
-                    copied_config.attn_config['attn_impl'] = 'torch'
-                    copied_config.init_device = 'cpu'
-
-                log.debug(f'Creating new model instance')
-
-                if composer_model.using_peft:
-                    # We don't use meta here because the state dict does not contain the full
-                    # model, only the adapter weights.
-                    active_adapter = original_model.active_adapter
-                    base_model = original_model.get_base_model()
-                    new_base_model_instance = type(base_model)(copied_config)
-
-                    new_model_instance = type(original_model)(
-                        new_base_model_instance,
-                        original_model.peft_config[active_adapter])
-                    new_model_instance.to(dtype=self.dtype)
-                else:
-                    # First create the model instance on meta device to avoid the
-                    # initialization cost.
-                    with init_empty_weights():
-                        new_model_instance = type(original_model)(copied_config)
-
-                    # Then load the state dict in with "assign" so that the state dict
-                    # is loaded properly even though the model is initially on meta device.
-                    new_model_instance.load_state_dict(state_dict, assign=True)
-                del state_dict
-
-                log.debug('Saving Hugging Face checkpoint to disk')
-                new_model_instance.save_pretrained(temp_save_dir)
-                if original_tokenizer is not None:
-                    assert isinstance(original_tokenizer,
-                                      PreTrainedTokenizerBase)
-                    original_tokenizer.save_pretrained(temp_save_dir)
-
-                # Only need to edit files for MPT because it has custom code
-                if original_model.config.model_type == 'mpt':
-                    log.debug('Editing MPT files for HuggingFace compatibility')
-                    edit_files_for_hf_compatibility(
-                        temp_save_dir,
-                        self.flatten_imports,
+                # First create the model instance on meta device to avoid the
+                # initialization cost.
+                with init_empty_weights():
+                    new_model_instance = type(original_model)(copied_config)
+
+                # Then load the state dict in with "assign" so that the state dict
+                # is loaded properly even though the model is initially on meta device.
+                new_model_instance.load_state_dict(state_dict, assign=True)
+            del state_dict
+
+            log.debug('Saving Hugging Face checkpoint to disk')
+            new_model_instance.save_pretrained(temp_save_dir)
+            if original_tokenizer is not None:
+                assert isinstance(original_tokenizer, PreTrainedTokenizerBase)
+                original_tokenizer.save_pretrained(temp_save_dir)
+
+            # Only need to edit files for MPT because it has custom code
+            if original_model.config.model_type == 'mpt':
+                log.debug('Editing MPT files for HuggingFace compatibility')
+                edit_files_for_hf_compatibility(
+                    temp_save_dir,
+                    self.flatten_imports,
+                )
+
+            if self.remote_ud is not None:
+                for filename in os.listdir(temp_save_dir):
+                    remote_file_name = os.path.join(save_dir, filename)
+                    remote_file_uri = self.remote_ud.remote_backend.get_uri(
+                        remote_file_name)
+                    log.info(
+                        f'Uploading HuggingFace formatted checkpoint to {remote_file_uri}'
+                    )
+                    self.remote_ud.upload_file(
+                        state=state,
+                        remote_file_name=remote_file_name,
+                        file_path=Path(os.path.join(temp_save_dir, filename)),
+                        overwrite=self.overwrite,
+                    )
 
-            if self.remote_ud is not None:
-                for filename in os.listdir(temp_save_dir):
-                    remote_file_name = os.path.join(save_dir, filename)
-                    remote_file_uri = self.remote_ud.remote_backend.get_uri(
-                        remote_file_name)
-                    log.info(
-                        f'Uploading HuggingFace formatted checkpoint to {remote_file_uri}'
-                    )
-                    self.remote_ud.upload_file(
-                        state=state,
-                        remote_file_name=remote_file_name,
-                        file_path=Path(os.path.join(temp_save_dir,
-                                                    filename)),
-                        overwrite=self.overwrite,
-                    )
+        dist.barrier()
 
-            if self.mlflow_registered_model_name and self._is_last_batch(
-                    state):
-                components = {'model': new_model_instance}
-                if original_tokenizer is not None:
-                    components['tokenizer'] = original_tokenizer
+        if dist.get_global_rank() == 0:
+            if self.mlflow_registered_model_name and self._is_last_batch(state):
+                components = {'model': new_model_instance}
+                if original_tokenizer is not None:
+                    components['tokenizer'] = original_tokenizer
 
-                log.debug('Logging Hugging Face model to MLFlow')
-                for i, mlflow_logger in enumerate(self.mlflow_loggers):
-                    log.debug(
-                        f'Registering model to UC at {mlflow_logger.model_registry_prefix}.{self.mlflow_registered_model_name}'
-                    )
-                    local_save_path = str(
-                        Path(temp_save_dir) / f'mlflow_save_{i}')
-
-                    # TODO: Remove after mlflow fixes the bug that makes this necessary
-                    import mlflow
-                    mlflow.store._unity_catalog.registry.rest_store.get_feature_dependencies = lambda *args, **kwargs: ''
-                    model_saving_kwargs: Dict[str, Any] = {
-                        'path': local_save_path
-                    }
-                    if composer_model.using_peft:
-                        model_saving_kwargs['flavor'] = 'peft'
-                        model_saving_kwargs[
-                            'save_pretrained_dir'] = temp_save_dir
-                        model_saving_kwargs[
-                            'metadata'] = self.mlflow_logging_config[
-                                'metadata']
-                    else:
-                        model_saving_kwargs['flavor'] = 'transformers'
-                        model_saving_kwargs[
-                            'transformers_model'] = components
-                        model_saving_kwargs.update(
-                            self.mlflow_logging_config)
-
-                    mlflow_logger.save_model(**model_saving_kwargs)
-
-                    # Upload the license file generated by mlflow during the model saving.
-                    license_filename = _maybe_get_license_filename(
-                        local_save_path,
-                        self.mlflow_logging_config['metadata'].get(
-                            'pretrained_model_name', None))
-                    if license_filename is not None:
-                        mlflow_logger._mlflow_client.log_artifact(
-                            mlflow_logger._run_id,
-                            os.path.join(local_save_path, license_filename),
-                        )
-
-                    mlflow_logger.register_model_with_run_id(
-                        model_uri=local_save_path,
-                        name=self.mlflow_registered_model_name,
-                        await_creation_for=3600,
-                    )
+                log.debug('Logging Hugging Face model to MLFlow')
+                for i, mlflow_logger in enumerate(self.mlflow_loggers):
+                    log.debug(
+                        f'Registering model to UC at {mlflow_logger.model_registry_prefix}.{self.mlflow_registered_model_name}'
+                    )
+                    local_save_path = str(
+                        Path(temp_save_dir) / f'mlflow_save_{i}')
+
+                    # TODO: Remove after mlflow fixes the bug that makes this necessary
+                    import mlflow
+                    mlflow.store._unity_catalog.registry.rest_store.get_feature_dependencies = lambda *args, **kwargs: ''
+                    model_saving_kwargs: Dict[str, Any] = {
+                        'path': local_save_path
+                    }
+                    if composer_model.using_peft:
+                        model_saving_kwargs['flavor'] = 'peft'
+                        model_saving_kwargs[
+                            'save_pretrained_dir'] = temp_save_dir
+                        model_saving_kwargs[
+                            'metadata'] = self.mlflow_logging_config['metadata']
+                    else:
+                        model_saving_kwargs['flavor'] = 'transformers'
+                        model_saving_kwargs['transformers_model'] = components
+                        model_saving_kwargs.update(self.mlflow_logging_config)
+
+                    mlflow_logger.save_model(**model_saving_kwargs)
+
+                    # Upload the license file generated by mlflow during the model saving.
+                    license_filename = _maybe_get_license_filename(
+                        local_save_path,
+                        self.mlflow_logging_config['metadata'].get(
+                            'pretrained_model_name', None))
+                    if license_filename is not None:
+                        mlflow_logger._mlflow_client.log_artifact(
+                            mlflow_logger._run_id,
+                            os.path.join(local_save_path, license_filename),
+                        )
+
+                    # Spawn a new process to register the model.
+                    process = SpawnProcess(
+                        target=_register_model_with_run_id_multiprocess,
+                        kwargs={
+                            'mlflow_logger':
+                                mlflow_logger,
+                            'composer_logging_level':
+                                logging.getLogger('composer').level,
+                            'model_uri':
+                                local_save_path,
+                            'name':
+                                self.mlflow_registered_model_name,
+                            'await_creation_for':
+                                3600,
+                        })
+                    process.start()
+                    self.child_processes.append(process)
+
+                # Save the temporary directory to be cleaned up later.
+                if use_temp_dir:
+                    self.temp_save_dir = temp_save_dir
+            else:
+                # Clean up the temporary directory if we don't need to register to mlflow.
+                if use_temp_dir:
+                    shutil.rmtree(temp_save_dir)
+
         dist.barrier()
diff --git a/tests/a_scripts/inference/test_convert_composer_to_hf.py b/tests/a_scripts/inference/test_convert_composer_to_hf.py
index 3949c091aa..3c6d9425f7 100644
--- a/tests/a_scripts/inference/test_convert_composer_to_hf.py
+++ b/tests/a_scripts/inference/test_convert_composer_to_hf.py
@@ -7,7 +7,7 @@
 import pathlib
 import shutil
 from argparse import Namespace
-from typing import Any, Callable, Optional, cast
+from typing import Any, Callable, Dict, Optional, cast
 from unittest.mock import ANY, MagicMock, patch
 
 import pytest
@@ -256,12 +256,34 @@ def test_callback_inits():
     assert hf_checkpointer.mlflow_logging_config['task'] == 'llm/v1/completions'
 
 
+class MockSpawnProcess:
+    """Class for mocking `multiprocessing.context.SpawnProcess`.
+
+    Runs `target(**kwargs)` on the main process.
+
+    Mock classes are not picklable and therefore cannot be used with
+    multiprocessing, so we need to patch SpawnProcess for tests.
+ """ + + def __init__(self, target: Callable, kwargs: Dict[str, Any]): + self.target = target + self.kwargs = kwargs + + def start(self): + self.target(**self.kwargs) + + def is_alive(self) -> bool: + return False + + @pytest.mark.gpu @pytest.mark.parametrize('log_to_mlflow', [True, False]) @pytest.mark.parametrize( 'hf_save_interval,save_interval,max_duration,expected_hf_checkpoints,expected_normal_checkpoints', [('3ba', '2ba', '4ba', 2, 2), ('1dur', '2ba', '1ep', 1, 2)]) @patch('os.cpu_count', MagicMock(return_value=1)) +@patch('llmfoundry.callbacks.hf_checkpointer.SpawnProcess', + new=MockSpawnProcess) def test_huggingface_conversion_callback_interval( tmp_path: pathlib.Path, log_to_mlflow: bool, hf_save_interval: str, save_interval: str, max_duration: str, expected_hf_checkpoints: int, @@ -406,6 +428,8 @@ def test_huggingface_conversion_callback_interval( 'hf_save_interval,save_interval,max_duration,expected_hf_checkpoints,expected_normal_checkpoints', [('1ba', '1ba', '1ba', 1, 1)]) @patch('os.cpu_count', MagicMock(return_value=1)) +@patch('llmfoundry.callbacks.hf_checkpointer.SpawnProcess', + new=MockSpawnProcess) def test_huggingface_conversion_callback( model: str, tmp_path: pathlib.Path,