diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py
index a95b68cf28..688d8deb74 100644
--- a/llmfoundry/callbacks/hf_checkpointer.py
+++ b/llmfoundry/callbacks/hf_checkpointer.py
@@ -60,6 +60,26 @@
 _LICENSE_FILE_PATTERN = re.compile(r'license(\.[a-z]+|$)', re.IGNORECASE)
 
+from contextlib import contextmanager
+
+
+@contextmanager
+def _monitor_process_saver(mlflow_logger: MLFlowLogger):
+    # Save the current monitor process and detach it from the logger
+    if hasattr(mlflow_logger, 'monitor_process'):
+        original_monitor_process = mlflow_logger.monitor_process  # type: ignore
+        mlflow_logger.monitor_process = None  # type: ignore
+    else:
+        original_monitor_process = None
+
+    try:
+        # Yield control back to the calling code
+        yield
+    finally:
+        # Restore the monitor process
+        if original_monitor_process is not None:
+            mlflow_logger.monitor_process = original_monitor_process  # type: ignore
+
 
 def _maybe_get_license_filename(
     local_dir: str,
@@ -108,6 +128,91 @@ def _maybe_get_license_filename(
     return None
 
 
+def _log_model_with_multi_process(
+    mlflow_logger: MLFlowLogger,
+    python_logging_level: int,
+    transformers_model: str,
+    artifact_path: str,
+    pretrained_model_name: str,
+    registered_model_name: Optional[str],
+    await_registration_for: int,
+    mlflow_logging_config: dict[str, Any],
+):
+    """Call MLFlowLogger.log_model from a child process.
+
+    First, patch mlflow's save_model function so that it removes duplicate
+    tokenizer files from the model directory. Then register the model with
+    mlflow from the child process.
+    """
+    # Set up logging for the child process. This ensures that any logs from composer are surfaced.
+    if python_logging_level > 0:
+        # If logging_level is 0, then the composer logger was unset.
+        logging.basicConfig(
+            format=
+            f'%(asctime)s: rank{dist.get_global_rank()}[%(process)d][%(threadName)s]: %(levelname)s: %(name)s: %(message)s',
+            force=True,
+        )
+        logging.getLogger('composer').setLevel(python_logging_level)
+        logging.getLogger('llmfoundry').setLevel(python_logging_level)
+
+    import mlflow
+    original_save_model = mlflow.transformers.save_model
+
+    def save_model_patch(*args: Any, **kwargs: Any):
+        original_save_model(*args, **kwargs)
+        tokenizer_files = []
+        save_path = kwargs['path']
+        tokenizer_path = os.path.join(save_path, 'components', 'tokenizer')
+        if os.path.exists(tokenizer_path):
+            tokenizer_files = os.listdir(tokenizer_path)
+        try:
+            # Check if there are duplicate tokenizer files in the model directory and remove them.
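+            # Note: mlflow's transformers flavor may write tokenizer files both
+            # under 'components/tokenizer' and next to the weights under
+            # 'model'; any copy found under 'model' is treated as a duplicate
+            # and removed so the artifact contains a single tokenizer.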
+            for tokenizer_file_name in tokenizer_files:
+                dupe_file = os.path.isfile(
+                    os.path.join(save_path, 'model', tokenizer_file_name),
+                )
+                if dupe_file:
+                    log.debug(
+                        f'Removing duplicate tokenizer file: {tokenizer_file_name}',
+                    )
+                    os.remove(
+                        os.path.join(save_path, 'model', tokenizer_file_name),
+                    )
+            license_filename = _maybe_get_license_filename(
+                save_path,
+                pretrained_model_name,
+            )
+            if license_filename is not None:
+                mlflow_logger._mlflow_client.log_artifact(
+                    mlflow_logger._run_id,
+                    os.path.join(save_path, license_filename),
+                )
+        except Exception as e:
+            log.error(
+                f'Exception when removing duplicate tokenizer files in the model directory: {e}',
+            )
+
+    mlflow.transformers.save_model = save_model_patch  # type: ignore
+
+    mlflow.set_tracking_uri(mlflow_logger.tracking_uri)
+    if mlflow_logger.model_registry_uri is not None:
+        mlflow.set_registry_uri(mlflow_logger.model_registry_uri)
+
+    register_model_path = f'{mlflow_logger.model_registry_prefix}.{registered_model_name}' if mlflow_logger.model_registry_prefix and registered_model_name else registered_model_name
+    mlflow_logger.log_model(
+        transformers_model=transformers_model,
+        flavor='transformers',
+        artifact_path=artifact_path,
+        registered_model_name=register_model_path,
+        run_id=mlflow_logger._run_id,
+        await_registration_for=await_registration_for,
+        **mlflow_logging_config,
+    )
+
+
 def _register_model_with_run_id_multiprocess(
     mlflow_logger: MLFlowLogger,
     composer_logging_level: int,
@@ -676,102 +781,149 @@ def tensor_hook(
 
         if dist.get_global_rank() == 0:
             if register_to_mlflow:
-                new_model_instance = self.transform_model_pre_registration(
-                    new_model_instance,
-                )
-
-                components = {'model': new_model_instance}
-                if original_tokenizer is not None:
-                    components['tokenizer'] = original_tokenizer
-
-                log.debug('Logging Hugging Face model to MLFlow')
-                for i, mlflow_logger in enumerate(self.mlflow_loggers):
-                    log.debug(
-                        f'Registering model to UC at {mlflow_logger.model_registry_prefix}.{self.mlflow_registered_model_name}',
-                    )
-                    local_save_path = str(
-                        Path(temp_save_dir) / f'mlflow_save_{i}',
-                    )
-
-                    # TODO: Remove after mlflow fixes the bug that makes this necessary
-                    import mlflow
-                    import mlflow.store
-                    mlflow.store._unity_catalog.registry.rest_store.get_feature_dependencies = lambda *args, **kwargs: ''
-                    model_saving_kwargs: dict[str, Any] = {
-                        'path': local_save_path,
-                    }
-                    if self.using_peft:
-                        model_saving_kwargs['flavor'] = 'peft'
-                        model_saving_kwargs['save_pretrained_dir'
-                                           ] = temp_save_dir
-                        model_saving_kwargs[
-                            'metadata'] = self.mlflow_logging_config['metadata']
-                    else:
-                        model_saving_kwargs['flavor'] = 'transformers'
-                        model_saving_kwargs['transformers_model'] = components
-                        model_saving_kwargs.update(self.mlflow_logging_config)
+                if self.using_peft:
-                    context_manager = te.onnx_export(
-                        True,
-                    ) if is_te_imported and state.precision == Precision.AMP_FP8 else contextlib.nullcontext(
+                    # Save and register the PEFT model to mlflow; this code path uses the older two-step logic.
+                    self._save_and_register_peft_model(
+                        state,
+                        new_model_instance,
+                        original_tokenizer,
+                        temp_save_dir,
                     )
-                    with context_manager:
-                        # Add the pip requirements directly to avoid mlflow
-                        # attempting to run inference on the model
-                        model_saving_kwargs['pip_requirements'] = [
-                            'transformers',
-                            'torch',
-                        ]
-                        mlflow_logger.save_model(**model_saving_kwargs)
-
-                    # Upload the license file generated by mlflow during the model saving.
-                    license_filename = _maybe_get_license_filename(
-                        local_save_path,
-                        self.pretrained_model_name,
+                else:
+                    register_save_dir = os.path.join(
+                        temp_save_dir,
+                        'register_save',
                     )
-                    if license_filename is not None:
-                        mlflow_logger._mlflow_client.log_artifact(
-                            mlflow_logger._run_id,
-                            os.path.join(local_save_path, license_filename),
-                        )
-
-                    self.pre_register_edit(local_save_path,)
-
-                    # Save the monitor process to be restored after registering the model.
-                    if hasattr(mlflow_logger, 'monitor_process'):
-                        monitor_process = mlflow_logger.monitor_process  # type: ignore
-                        mlflow_logger.monitor_process = None  # type: ignore
-                    else:
-                        monitor_process = None
-
-                    # Spawn a new process to register the model.
-                    process = SpawnProcess(
-                        target=_register_model_with_run_id_multiprocess,
-                        kwargs={
-                            'mlflow_logger':
-                                mlflow_logger,
-                            'composer_logging_level':
-                                logging.getLogger('composer').level,
-                            'model_uri':
-                                local_save_path,
-                            'name':
-                                self.mlflow_registered_model_name,
-                            'await_creation_for':
-                                3600,
-                        },
+                    assert new_model_instance is not None
+                    new_model_instance = self.transform_model_pre_registration(
+                        new_model_instance,
                     )
-                    process.start()
-
-                    # Restore the monitor process.
-                    if monitor_process is not None:
-                        mlflow_logger.monitor_process = monitor_process  # type: ignore
-                    self.register_processes.append(process)
-
-                    # Save the temporary directory to be cleaned up later.
-                    if use_temp_dir:
-                        self.temp_save_dir = temp_save_dir
+                    new_model_instance.save_pretrained(register_save_dir)
+                    if original_tokenizer:
+                        original_tokenizer.save_pretrained(register_save_dir)
+
+                    self.pre_register_edit(register_save_dir)
+
+                    for mlflow_logger in self.mlflow_loggers:
+                        if self.mlflow_registered_model_name:
+                            log.debug(
+                                f'Registering model to UC at {mlflow_logger.model_registry_prefix}.{self.mlflow_registered_model_name}',
+                            )
+
+                        # Save the monitor process to be restored after registering the model.
+                        with _monitor_process_saver(mlflow_logger):
+                            process = SpawnProcess(
+                                target=_log_model_with_multi_process,
+                                kwargs={
+                                    'mlflow_logger':
+                                        mlflow_logger,
+                                    'python_logging_level':
+                                        logging.getLogger('llmfoundry').level,
+                                    'transformers_model':
+                                        register_save_dir,
+                                    'artifact_path':
+                                        'final_model_checkpoint',
+                                    'pretrained_model_name':
+                                        self.pretrained_model_name,
+                                    'registered_model_name':
+                                        self.mlflow_registered_model_name,
+                                    'await_registration_for':
+                                        3600,
+                                    'mlflow_logging_config':
+                                        self.mlflow_logging_config,
+                                },
+                            )
+
+                            process.start()
+                            self.register_processes.append(process)
+
+                # Save the temporary directory to be cleaned up later.
+                if use_temp_dir:
+                    self.temp_save_dir = temp_save_dir
             else:
                 # Clean up the temporary directory if we don't need to register to mlflow.
                 if use_temp_dir:
                     shutil.rmtree(temp_save_dir)
         dist.barrier()
+
+    def _save_and_register_peft_model(
+        self,
+        state: State,
+        new_model_instance: Any,
+        original_tokenizer: Optional[Any],
+        save_dir: str,
+    ):
+        new_model_instance = self.transform_model_pre_registration(
+            new_model_instance,
+        )
+        components = {'model': new_model_instance}
+        if original_tokenizer is not None:
+            components['tokenizer'] = original_tokenizer
+
+        log.debug('Logging Hugging Face model to MLFlow')
+        for i, mlflow_logger in enumerate(self.mlflow_loggers):
+            log.debug(
+                f'Registering model to UC at {mlflow_logger.model_registry_prefix}.{self.mlflow_registered_model_name}',
+            )
+
+            local_save_path = str(Path(save_dir) / f'mlflow_save_{i}')
+
+            # TODO: Remove after mlflow fixes the bug that makes this necessary
+            import mlflow
+            import mlflow.store
+            mlflow.store._unity_catalog.registry.rest_store.get_feature_dependencies = lambda *args, **kwargs: ''
+            model_saving_kwargs: dict[str, Any] = {
+                'path': local_save_path,
+            }
+            model_saving_kwargs['flavor'] = 'peft'
+            model_saving_kwargs['save_pretrained_dir'] = save_dir
+            model_saving_kwargs['metadata'] = self.mlflow_logging_config[
+                'metadata']
+
+            context_manager = te.onnx_export(
+                True,
+            ) if is_te_imported and state.precision == Precision.AMP_FP8 else contextlib.nullcontext(
+            )
+            with context_manager:
+                # Add the pip requirements directly to avoid mlflow
+                # attempting to run inference on the model
+                model_saving_kwargs['pip_requirements'] = [
+                    'transformers',
+                    'torch',
+                ]
+                mlflow_logger.save_model(**model_saving_kwargs)
+
+            # Upload the license file generated by mlflow during the model save.
+            license_filename = _maybe_get_license_filename(
+                local_save_path,
+                self.pretrained_model_name,
+            )
+            if license_filename is not None:
+                mlflow_logger._mlflow_client.log_artifact(
+                    mlflow_logger._run_id,
+                    os.path.join(local_save_path, license_filename),
+                )
+
+            self.pre_register_edit(local_save_path)
+
+            with _monitor_process_saver(mlflow_logger):
+                process = SpawnProcess(
+                    target=_register_model_with_run_id_multiprocess,
+                    kwargs={
+                        'mlflow_logger':
+                            mlflow_logger,
+                        'composer_logging_level':
+                            logging.getLogger('composer').level,
+                        'model_uri':
+                            local_save_path,
+                        'name':
+                            self.mlflow_registered_model_name,
+                        'await_creation_for':
+                            3600,
+                    },
+                )
+                process.start()
+                self.register_processes.append(process)
diff --git a/tests/a_scripts/inference/test_convert_composer_to_hf.py b/tests/a_scripts/inference/test_convert_composer_to_hf.py
index c25432dc48..f599ebbc16 100644
--- a/tests/a_scripts/inference/test_convert_composer_to_hf.py
+++ b/tests/a_scripts/inference/test_convert_composer_to_hf.py
@@ -13,6 +13,7 @@
 from unittest.mock import ANY, MagicMock, patch
 
 import catalogue
+import numpy as np
 import pytest
 import torch
 import torch.nn as nn
@@ -341,14 +342,17 @@ def is_alive(self) -> bool:
 
 def _create_mlflow_logger_mock() -> MagicMock:
     mlflow_logger_mock = MagicMock(spec=MLFlowLogger)
-    mlflow_logger_mock.state_dict = lambda *args, **kwargs: {}
-    mlflow_logger_mock.save_model = MagicMock(wraps=_save_model_mock)
-    mlflow_logger_mock.register_model_with_run_id = MagicMock()
-    mlflow_logger_mock.model_registry_prefix = ''
+    mlflow_logger_mock._mlflow_client = MagicMock()
     mlflow_logger_mock._experiment_id = 'mlflow-experiment-id'
     mlflow_logger_mock._run_id = 'mlflow-run-id'
     mlflow_logger_mock._enabled = True
+    mlflow_logger_mock.log_model = MagicMock()
+    mlflow_logger_mock.model_registry_prefix = ''
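+    # The new registration path spawns _log_model_with_multi_process, which
+    # reads tracking_uri and model_registry_uri off the logger in the child
+    # process, so the mock must define both attributes explicitly.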
+    mlflow_logger_mock.model_registry_uri = None
+    mlflow_logger_mock.state_dict = lambda *args, **kwargs: {}
+    mlflow_logger_mock.save_model = MagicMock(wraps=_save_model_mock)
     mlflow_logger_mock.run_url = 'fake-url'
+    mlflow_logger_mock.tracking_uri = None
     return mlflow_logger_mock
 
@@ -432,10 +436,10 @@ def test_final_register_only(
 
     if mlflow_registered_model_name is not None:
         # We should always attempt to register the model once
-        assert mlflow_logger_mock.register_model_with_run_id.call_count == 1
+        assert mlflow_logger_mock.log_model.call_count == 1
         if mlflow_registry_error:
             # If the registry fails, we should still save the model
-            assert mlflow_logger_mock.register_model_with_run_id.call_count == 1
+            assert mlflow_logger_mock.log_model.call_count == 1
             assert checkpointer_callback._save_checkpoint.call_count == 2
             assert checkpointer_callback._save_checkpoint.call_args_list[
                 0].kwargs == {
@@ -457,7 +461,7 @@ def test_final_register_only(
             }
     else:
         # No mlflow_registered_model_name, so we should only save the checkpoint
-        assert mlflow_logger_mock.register_model_with_run_id.call_count == 0
+        assert mlflow_logger_mock.log_model.call_count == 0
         assert checkpointer_callback._save_checkpoint.call_count == 1
         assert checkpointer_callback._save_checkpoint.call_args_list[
             0].kwargs == {
@@ -512,6 +516,7 @@ def test_huggingface_conversion_callback_interval(
     optimizer = _create_optimizer(original_model)
 
     mlflow_logger_mock = _create_mlflow_logger_mock()
+
     checkpointer_callback.transform_model_pre_registration = MagicMock(
         wraps=checkpointer_callback.transform_model_pre_registration,
     )
@@ -533,29 +538,33 @@ def test_huggingface_conversion_callback_interval(
     trainer.fit()
 
     if log_to_mlflow:
-        assert mlflow_logger_mock.save_model.call_count == 1
-        mlflow_logger_mock.save_model.assert_called_with(
-            flavor='transformers',
+        assert mlflow_logger_mock.log_model.call_count == 1
+        mlflow_logger_mock.log_model.assert_called_with(
             transformers_model=ANY,
-            path=ANY,
-            task='llm/v1/completions',
-            input_example=ANY,
-            metadata={},
-            pip_requirements=ANY,
+            flavor='transformers',
+            artifact_path='final_model_checkpoint',
+            registered_model_name='dummy-registered-name',
+            run_id='mlflow-run-id',
+            await_registration_for=3600,
+            metadata=ANY,
+            task=ANY,
+            input_example={
+                'prompt': np.array(['What is Machine Learning?']),
+            },
         )
         assert checkpointer_callback.transform_model_pre_registration.call_count == 1
        assert checkpointer_callback.pre_register_edit.call_count == 1
-        assert mlflow_logger_mock.register_model_with_run_id.call_count == 1
+        assert mlflow_logger_mock.log_model.call_count == 1
     else:
         assert checkpointer_callback.transform_model_pre_registration.call_count == 0
         assert checkpointer_callback.pre_register_edit.call_count == 0
-        assert mlflow_logger_mock.save_model.call_count == 0
-        assert mlflow_logger_mock.register_model_with_run_id.call_count == 0
+        assert mlflow_logger_mock.log_model.call_count == 0
 
     normal_checkpoints = [
         name for name in os.listdir(os.path.join(tmp_path, 'checkpoints'))
         if name != 'huggingface'
     ]
+
     huggingface_checkpoints = list(
         os.listdir(os.path.join(tmp_path, 'checkpoints', 'huggingface')),
     )
@@ -699,7 +708,6 @@ def _assert_mlflow_logger_calls(
     peft_config: Optional[dict] = None,
 ):
     if dist.get_global_rank() == 0:
-        assert mlflow_logger_mock.save_model.call_count == 1
         if peft_config is not None:
             expectation = {
                 'flavor': 'peft',
                 'path': ANY,
                 'save_pretrained_dir': ANY,
                 'metadata': {},
             }
+            assert mlflow_logger_mock.save_model.call_count == 1
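+            # PEFT checkpoints still go through the legacy save_model +
+            # register flow, so log_model is not expected to be called here.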
         else:
-            import numpy as np
-
             default_input_example = {
                 'prompt': np.array(['What is Machine Learning?']),
             }
-
             expectation = {
-                'flavor': 'transformers',
                 'transformers_model': ANY,
-                'path': ANY,
-                'task': 'llm/v1/completions',
+                'flavor': 'transformers',
+                'artifact_path': 'final_model_checkpoint',
+                'registered_model_name': 'dummy-registered-name',
+                'run_id': 'mlflow-run-id',
+                'await_registration_for': 3600,
+                'metadata': ANY,
+                'task': ANY,
                 'input_example': default_input_example,
-                'metadata': {},
-                'pip_requirements': ANY,
             }
-        mlflow_logger_mock.save_model.assert_called_with(**expectation)
-        assert mlflow_logger_mock.register_model_with_run_id.call_count == 1
+            assert mlflow_logger_mock.log_model.call_count == 1
+            mlflow_logger_mock.log_model.assert_called_with(**expectation)
     else:
         assert mlflow_logger_mock.log_model.call_count == 0
-        assert mlflow_logger_mock.register_model_with_run_id.call_count == 0
 
 
 def _get_fsdp_config(fsdp_state_dict_type: Optional[str]):
@@ -1039,12 +1046,14 @@ def test_huggingface_conversion_callback(
     mlflow_logger_mock = MagicMock(spec=MLFlowLogger)
     mlflow_logger_mock.state_dict = lambda *args, **kwargs: {}
     mlflow_logger_mock.save_model = MagicMock(wraps=_save_model_mock)
-    mlflow_logger_mock.register_model_with_run_id = MagicMock()
+    mlflow_logger_mock.log_model = MagicMock()
     mlflow_logger_mock.model_registry_prefix = ''
     mlflow_logger_mock._experiment_id = 'mlflow-experiment-id'
     mlflow_logger_mock._run_id = 'mlflow-run-id'
     mlflow_logger_mock._enabled = True
     mlflow_logger_mock.run_url = 'fake-url'
+    mlflow_logger_mock.tracking_uri = None
+    mlflow_logger_mock.model_registry_uri = None
 
     trainer = Trainer(
         model=original_model,
         device='gpu',