From 8594a80ed621b6cab75f9017001bc369c0bd1a09 Mon Sep 17 00:00:00 2001 From: Irene Dea Date: Sat, 30 Mar 2024 20:29:42 +0000 Subject: [PATCH 01/14] Background mlflow register model --- llmfoundry/callbacks/hf_checkpointer.py | 50 ++++++++++++++++++++++--- 1 file changed, 44 insertions(+), 6 deletions(-) diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py index 9eb23f1030..91e5ee6ff9 100644 --- a/llmfoundry/callbacks/hf_checkpointer.py +++ b/llmfoundry/callbacks/hf_checkpointer.py @@ -8,8 +8,10 @@ import os import re import tempfile +import time +from multiprocessing.context import SpawnProcess from pathlib import Path -from typing import Any, Dict, Optional, Sequence, Union +from typing import Any, Dict, List, Optional, Sequence, Union import torch from composer.core import Callback, Event, State, Time, TimeUnit @@ -72,6 +74,19 @@ def _maybe_get_license_filename( return None +def _register_model_with_run_id_multiprocess(mlflow_logger: MLFlowLogger, + logging_level: int, model_uri: str, + name: str, + await_creation_for: int): + logging.basicConfig( + format= + f'%(asctime)s: rank{dist.get_global_rank()}[%(process)d][%(threadName)s]: %(levelname)s: %(name)s: %(message)s' + ) + logging.getLogger('composer').setLevel(logging_level) + mlflow_logger.register_model_with_run_id( + model_uri=model_uri, name=name, await_creation_for=await_creation_for) + + class HuggingFaceCheckpointer(Callback): """Save a huggingface formatted checkpoint during training. @@ -170,6 +185,7 @@ def __init__( self.last_checkpoint_batch: Optional[Time] = None self.mlflow_loggers = [] + self.child_processes: List[SpawnProcess] = [] def run_event(self, event: Event, state: State, logger: Logger) -> None: # The interval scheduler handles only returning True for the appropriate events @@ -202,6 +218,10 @@ def run_event(self, event: Event, state: State, logger: Logger) -> None: import mlflow mlflow.environment_variables.MLFLOW_HUGGINGFACE_MODEL_MAX_SHARD_SIZE.set( '5GB') + elif event == Event.FIT_END: + # Wait for all child processes spawned by the callback to finish. + while not self._all_child_processes_done(): + time.sleep(30) def _is_last_batch(self, state: State): elapsed_duration = state.get_elapsed_duration() @@ -218,6 +238,12 @@ def _is_last_batch(self, state: State): return False + def _all_child_processes_done(self) -> bool: + not_done = any(process.is_alive() for process in self.child_processes) + x = torch.tensor(1 if not_done else 0).to(device='cuda') + dist.all_reduce(x, reduce_operation='MAX') + return x.item() == 0 + def _save_checkpoint(self, state: State, logger: Logger): del logger # unused @@ -385,8 +411,20 @@ def _save_checkpoint(self, state: State, logger: Logger): os.path.join(local_save_path, license_filename), ) - mlflow_logger.register_model_with_run_id( - model_uri=local_save_path, - name=self.mlflow_registered_model_name, - await_creation_for=3600, - ) + # Register the model to mlflow in a child process. + process = SpawnProcess( + target=_register_model_with_run_id_multiprocess, + kwargs={ + 'mlflow_logger': + mlflow_logger, + 'logging_level': + logging.getLogger('composer').level, + 'model_uri': + local_save_path, + 'name': + self.mlflow_registered_model_name, + 'await_creation_for': + 3600, + }) + process.start() + self.child_processes.append(process) From 62f79302bed7836c9429ba2e324a217659163b2f Mon Sep 17 00:00:00 2001 From: Irene Dea Date: Sat, 30 Mar 2024 20:36:18 +0000 Subject: [PATCH 02/14] Add comments --- llmfoundry/callbacks/hf_checkpointer.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py index 91e5ee6ff9..7104d4d0f2 100644 --- a/llmfoundry/callbacks/hf_checkpointer.py +++ b/llmfoundry/callbacks/hf_checkpointer.py @@ -78,11 +78,18 @@ def _register_model_with_run_id_multiprocess(mlflow_logger: MLFlowLogger, logging_level: int, model_uri: str, name: str, await_creation_for: int): + """Function for calling MLFlowLogger.register_model_with_run_id from a. + + spawned child process. + """ + # Setup logging for child process. This ensures that any logs from composer are surfaced. logging.basicConfig( format= f'%(asctime)s: rank{dist.get_global_rank()}[%(process)d][%(threadName)s]: %(levelname)s: %(name)s: %(message)s' ) logging.getLogger('composer').setLevel(logging_level) + + # Register model. mlflow_logger.register_model_with_run_id( model_uri=model_uri, name=name, await_creation_for=await_creation_for) @@ -411,7 +418,7 @@ def _save_checkpoint(self, state: State, logger: Logger): os.path.join(local_save_path, license_filename), ) - # Register the model to mlflow in a child process. + # Spawn a new process to register the model. process = SpawnProcess( target=_register_model_with_run_id_multiprocess, kwargs={ From 1b875ceb7ca77e2a40f5f3fcdcd89c571dad1708 Mon Sep 17 00:00:00 2001 From: Irene Dea Date: Sat, 30 Mar 2024 20:40:42 +0000 Subject: [PATCH 03/14] Change sleep time to 2 to match remote uploader from composer --- llmfoundry/callbacks/hf_checkpointer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py index 7104d4d0f2..907b08b01c 100644 --- a/llmfoundry/callbacks/hf_checkpointer.py +++ b/llmfoundry/callbacks/hf_checkpointer.py @@ -228,7 +228,7 @@ def run_event(self, event: Event, state: State, logger: Logger) -> None: elif event == Event.FIT_END: # Wait for all child processes spawned by the callback to finish. while not self._all_child_processes_done(): - time.sleep(30) + time.sleep(2) def _is_last_batch(self, state: State): elapsed_duration = state.get_elapsed_duration() From c9c22a7ffaf0c3f834684e3af24b73216d1c4b68 Mon Sep 17 00:00:00 2001 From: Irene Dea Date: Sun, 31 Mar 2024 01:05:18 +0000 Subject: [PATCH 04/14] Fix tests --- .../inference/test_convert_composer_to_hf.py | 24 ++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/tests/a_scripts/inference/test_convert_composer_to_hf.py b/tests/a_scripts/inference/test_convert_composer_to_hf.py index 3949c091aa..9d8a0c747c 100644 --- a/tests/a_scripts/inference/test_convert_composer_to_hf.py +++ b/tests/a_scripts/inference/test_convert_composer_to_hf.py @@ -7,7 +7,7 @@ import pathlib import shutil from argparse import Namespace -from typing import Any, Callable, Optional, cast +from typing import Any, Callable, Dict, Optional, cast from unittest.mock import ANY, MagicMock, patch import pytest @@ -256,12 +256,34 @@ def test_callback_inits(): assert hf_checkpointer.mlflow_logging_config['task'] == 'llm/v1/completions' +class MockSpawnProcess: + """Class for mocking `multiprocessing.context.SpawnProcess`. + + Runs `target(**kwargs)` on the main process. + + Mock classes are not picklable and therefore cannot be used with + multiprocessing, so we need to patch SpawnProcess for tests. + """ + + def __init__(self, target: Callable, kwargs: Dict[str, Any]): + self.target = target + self.kwargs = kwargs + + def start(self): + self.target(**self.kwargs) + + def is_alive(self) -> bool: + return False + + @pytest.mark.gpu @pytest.mark.parametrize('log_to_mlflow', [True, False]) @pytest.mark.parametrize( 'hf_save_interval,save_interval,max_duration,expected_hf_checkpoints,expected_normal_checkpoints', [('3ba', '2ba', '4ba', 2, 2), ('1dur', '2ba', '1ep', 1, 2)]) @patch('os.cpu_count', MagicMock(return_value=1)) +@patch('llmfoundry.callbacks.hf_checkpointer.SpawnProcess', + new=MockSpawnProcess) def test_huggingface_conversion_callback_interval( tmp_path: pathlib.Path, log_to_mlflow: bool, hf_save_interval: str, save_interval: str, max_duration: str, expected_hf_checkpoints: int, From 1cdc78dd61441a7d811b82c928c920731c78162e Mon Sep 17 00:00:00 2001 From: Irene Dea Date: Sun, 31 Mar 2024 01:55:13 +0000 Subject: [PATCH 05/14] Add barriers and manually clean up tempdir --- llmfoundry/callbacks/hf_checkpointer.py | 334 ++++++++++++------------ 1 file changed, 173 insertions(+), 161 deletions(-) diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py index 907b08b01c..fd60349fd4 100644 --- a/llmfoundry/callbacks/hf_checkpointer.py +++ b/llmfoundry/callbacks/hf_checkpointer.py @@ -7,6 +7,7 @@ import math import os import re +import shutil import tempfile import time from multiprocessing.context import SpawnProcess @@ -192,7 +193,10 @@ def __init__( self.last_checkpoint_batch: Optional[Time] = None self.mlflow_loggers = [] + self.child_processes: List[SpawnProcess] = [] + # Temporary save directory used by child_processes. + self.temp_save_dir = None def run_event(self, event: Event, state: State, logger: Logger) -> None: # The interval scheduler handles only returning True for the appropriate events @@ -230,6 +234,10 @@ def run_event(self, event: Event, state: State, logger: Logger) -> None: while not self._all_child_processes_done(): time.sleep(2) + # Clean up temporary save directory; all processes are done with it. + if self.temp_save_dir is not None: + shutil.rmtree(self.temp_save_dir) + def _is_last_batch(self, state: State): elapsed_duration = state.get_elapsed_duration() if elapsed_duration is not None and elapsed_duration >= 1.0: @@ -246,6 +254,8 @@ def _is_last_batch(self, state: State): return False def _all_child_processes_done(self) -> bool: + if len(self.child_processes) == 0: + return True not_done = any(process.is_alive() for process in self.child_processes) x = torch.tensor(1 if not_done else 0).to(device='cuda') dist.all_reduce(x, reduce_operation='MAX') @@ -268,170 +278,172 @@ def _save_checkpoint(self, state: State, logger: Logger): Path(self.save_dir_format_str) / self.huggingface_folder_name_fstr), state.run_name, state.timestamp) - dir_context_mgr = tempfile.TemporaryDirectory( - ) if self.remote_ud is not None else contextlib.nullcontext( - enter_result=save_dir) - - with dir_context_mgr as temp_save_dir: - assert isinstance(temp_save_dir, - str) # pyright doesn't know about enter_result - - log.debug('Gathering state dict') - from torch.distributed.fsdp import FullyShardedDataParallel as FSDP - - if state.is_model_ddp: - composer_model = state.model.module - original_model: PreTrainedModel = state.model.module.model - state_dict_model = state.model.module.model - original_tokenizer = state.model.module.tokenizer - elif isinstance(state.model.model, FSDP): - composer_model = state.model - original_model: PreTrainedModel = state.model.model.module - state_dict_model = state.model.model - original_tokenizer = state.model.tokenizer + + temp_save_dir = tempfile.mkdtemp( + ) if self.remote_ud is not None else save_dir + + log.debug('Gathering state dict') + from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + + if state.is_model_ddp: + composer_model = state.model.module + original_model: PreTrainedModel = state.model.module.model + state_dict_model = state.model.module.model + original_tokenizer = state.model.module.tokenizer + elif isinstance(state.model.model, FSDP): + composer_model = state.model + original_model: PreTrainedModel = state.model.model.module + state_dict_model = state.model.model + original_tokenizer = state.model.tokenizer + else: + composer_model = state.model + original_model: PreTrainedModel = state.model.model + state_dict_model = state.model.model + original_tokenizer = state.model.tokenizer + + state_dict_context = fsdp_state_dict_type_context( + original_model, + state_dict_type='full') if ((not state.is_model_ddp) and isinstance( + state_dict_model, FSDP)) else contextlib.nullcontext() + + with state_dict_context: + state_dict = state_dict_model.state_dict() + + # convert the state dict to the requested precision + for k, v in state_dict.items(): + if isinstance(v, torch.Tensor): + state_dict[k] = v.to(dtype=self.dtype) + + new_model_instance = None # Need this for pyright because variable could be unbound + + if dist.get_global_rank() == 0: + log.debug('Saving Hugging Face checkpoint in global rank 0') + + copied_config = copy.deepcopy(original_model.config) + if copied_config.model_type == 'mpt': + copied_config.attn_config['attn_impl'] = 'torch' + copied_config.init_device = 'cpu' + + log.debug(f'Creating new model instance') + + if composer_model.using_peft: + # We don't use meta here because the state dict does not contain the full + # model, only the adapter weights. + active_adapter = original_model.active_adapter + base_model = original_model.get_base_model() + new_base_model_instance = type(base_model)(copied_config) + + new_model_instance = type(original_model)( + new_base_model_instance, + original_model.peft_config[active_adapter]) + new_model_instance.to(dtype=self.dtype) else: - composer_model = state.model - original_model: PreTrainedModel = state.model.model - state_dict_model = state.model.model - original_tokenizer = state.model.tokenizer - - state_dict_context = fsdp_state_dict_type_context( - original_model, state_dict_type='full') if ( - (not state.is_model_ddp) and isinstance( - state_dict_model, FSDP)) else contextlib.nullcontext() - - with state_dict_context: - state_dict = state_dict_model.state_dict() - - # convert the state dict to the requested precision - for k, v in state_dict.items(): - if isinstance(v, torch.Tensor): - state_dict[k] = v.to(dtype=self.dtype) - - if dist.get_global_rank() == 0: - log.debug('Saving Hugging Face checkpoint in global rank 0') - - copied_config = copy.deepcopy(original_model.config) - if copied_config.model_type == 'mpt': - copied_config.attn_config['attn_impl'] = 'torch' - copied_config.init_device = 'cpu' - - log.debug(f'Creating new model instance') - - if composer_model.using_peft: - # We don't use meta here because the state dict does not contain the full - # model, only the adapter weights. - active_adapter = original_model.active_adapter - base_model = original_model.get_base_model() - new_base_model_instance = type(base_model)(copied_config) - - new_model_instance = type(original_model)( - new_base_model_instance, - original_model.peft_config[active_adapter]) - new_model_instance.to(dtype=self.dtype) - else: - # First create the model instance on meta device to avoid the - # initialization cost. - with init_empty_weights(): - new_model_instance = type(original_model)(copied_config) - - # Then load the state dict in with "assign" so that the state dict - # is loaded properly even though the model is initially on meta device. - new_model_instance.load_state_dict(state_dict, assign=True) - del state_dict - - log.debug('Saving Hugging Face checkpoint to disk') - new_model_instance.save_pretrained(temp_save_dir) - if original_tokenizer is not None: - assert isinstance(original_tokenizer, - PreTrainedTokenizerBase) - original_tokenizer.save_pretrained(temp_save_dir) - - # Only need to edit files for MPT because it has custom code - if original_model.config.model_type == 'mpt': - log.debug('Editing MPT files for HuggingFace compatibility') - edit_files_for_hf_compatibility( - temp_save_dir, - self.flatten_imports, + # First create the model instance on meta device to avoid the + # initialization cost. + with init_empty_weights(): + new_model_instance = type(original_model)(copied_config) + + # Then load the state dict in with "assign" so that the state dict + # is loaded properly even though the model is initially on meta device. + new_model_instance.load_state_dict(state_dict, assign=True) + del state_dict + + log.debug('Saving Hugging Face checkpoint to disk') + new_model_instance.save_pretrained(temp_save_dir) + if original_tokenizer is not None: + assert isinstance(original_tokenizer, PreTrainedTokenizerBase) + original_tokenizer.save_pretrained(temp_save_dir) + + # Only need to edit files for MPT because it has custom code + if original_model.config.model_type == 'mpt': + log.debug('Editing MPT files for HuggingFace compatibility') + edit_files_for_hf_compatibility( + temp_save_dir, + self.flatten_imports, + ) + + if self.remote_ud is not None: + for filename in os.listdir(temp_save_dir): + remote_file_name = os.path.join(save_dir, filename) + remote_file_uri = self.remote_ud.remote_backend.get_uri( + remote_file_name) + log.info( + f'Uploading HuggingFace formatted checkpoint to {remote_file_uri}' + ) + self.remote_ud.upload_file( + state=state, + remote_file_name=remote_file_name, + file_path=Path(os.path.join(temp_save_dir, filename)), + overwrite=self.overwrite, ) - if self.remote_ud is not None: - for filename in os.listdir(temp_save_dir): - remote_file_name = os.path.join(save_dir, filename) - remote_file_uri = self.remote_ud.remote_backend.get_uri( - remote_file_name) - log.info( - f'Uploading HuggingFace formatted checkpoint to {remote_file_uri}' - ) - self.remote_ud.upload_file( - state=state, - remote_file_name=remote_file_name, - file_path=Path(os.path.join(temp_save_dir, - filename)), - overwrite=self.overwrite, - ) + dist.barrier() - if self.mlflow_registered_model_name and self._is_last_batch( - state): - components = {'model': new_model_instance} - if original_tokenizer is not None: - components['tokenizer'] = original_tokenizer + if dist.get_global_rank() == 0: + if self.mlflow_registered_model_name and self._is_last_batch(state): + components = {'model': new_model_instance} + if original_tokenizer is not None: + components['tokenizer'] = original_tokenizer - log.debug('Logging Hugging Face model to MLFlow') - for i, mlflow_logger in enumerate(self.mlflow_loggers): - log.debug( - f'Registering model to UC at {mlflow_logger.model_registry_prefix}.{self.mlflow_registered_model_name}' + log.debug('Logging Hugging Face model to MLFlow') + for i, mlflow_logger in enumerate(self.mlflow_loggers): + log.debug( + f'Registering model to UC at {mlflow_logger.model_registry_prefix}.{self.mlflow_registered_model_name}' + ) + local_save_path = str( + Path(temp_save_dir) / f'mlflow_save_{i}') + + # TODO: Remove after mlflow fixes the bug that makes this necessary + import mlflow + mlflow.store._unity_catalog.registry.rest_store.get_feature_dependencies = lambda *args, **kwargs: '' + model_saving_kwargs: Dict[str, Any] = { + 'path': local_save_path + } + if composer_model.using_peft: + model_saving_kwargs['flavor'] = 'peft' + model_saving_kwargs[ + 'save_pretrained_dir'] = temp_save_dir + model_saving_kwargs[ + 'metadata'] = self.mlflow_logging_config['metadata'] + else: + model_saving_kwargs['flavor'] = 'transformers' + model_saving_kwargs['transformers_model'] = components + model_saving_kwargs.update(self.mlflow_logging_config) + + mlflow_logger.save_model(**model_saving_kwargs) + + # Upload the license file generated by mlflow during the model saving. + license_filename = _maybe_get_license_filename( + local_save_path, + self.mlflow_logging_config['metadata'].get( + 'pretrained_model_name', None)) + if license_filename is not None: + mlflow_logger._mlflow_client.log_artifact( + mlflow_logger._run_id, + os.path.join(local_save_path, license_filename), ) - local_save_path = str( - Path(temp_save_dir) / f'mlflow_save_{i}') - - # TODO: Remove after mlflow fixes the bug that makes this necessary - import mlflow - mlflow.store._unity_catalog.registry.rest_store.get_feature_dependencies = lambda *args, **kwargs: '' - model_saving_kwargs: Dict[str, Any] = { - 'path': local_save_path - } - if composer_model.using_peft: - model_saving_kwargs['flavor'] = 'peft' - model_saving_kwargs[ - 'save_pretrained_dir'] = temp_save_dir - model_saving_kwargs[ - 'metadata'] = self.mlflow_logging_config[ - 'metadata'] - else: - model_saving_kwargs['flavor'] = 'transformers' - model_saving_kwargs[ - 'transformers_model'] = components - model_saving_kwargs.update( - self.mlflow_logging_config) - - mlflow_logger.save_model(**model_saving_kwargs) - - # Upload the license file generated by mlflow during the model saving. - license_filename = _maybe_get_license_filename( - local_save_path, - self.mlflow_logging_config['metadata'].get( - 'pretrained_model_name', None)) - if license_filename is not None: - mlflow_logger._mlflow_client.log_artifact( - mlflow_logger._run_id, - os.path.join(local_save_path, license_filename), - ) - - # Spawn a new process to register the model. - process = SpawnProcess( - target=_register_model_with_run_id_multiprocess, - kwargs={ - 'mlflow_logger': - mlflow_logger, - 'logging_level': - logging.getLogger('composer').level, - 'model_uri': - local_save_path, - 'name': - self.mlflow_registered_model_name, - 'await_creation_for': - 3600, - }) - process.start() - self.child_processes.append(process) + + # Spawn a new process to register the model. + process = SpawnProcess( + target=_register_model_with_run_id_multiprocess, + kwargs={ + 'mlflow_logger': + mlflow_logger, + 'logging_level': + logging.getLogger('composer').level, + 'model_uri': + local_save_path, + 'name': + self.mlflow_registered_model_name, + 'await_creation_for': + 3600, + }) + process.start() + self.child_processes.append(process) + + # Save the temporary directory to be cleaned up later. + self.temp_save_dir = temp_save_dir + else: + # Clean up the temporary directory if we don't need to register to mlflow. + shutil.rmtree(temp_save_dir) + dist.barrier() From 2ac1588872eab087c331bddb4f8849fdfbf43b1f Mon Sep 17 00:00:00 2001 From: Irene Dea Date: Sun, 31 Mar 2024 05:36:11 +0000 Subject: [PATCH 06/14] Fix test with mocking --- tests/a_scripts/inference/test_convert_composer_to_hf.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/a_scripts/inference/test_convert_composer_to_hf.py b/tests/a_scripts/inference/test_convert_composer_to_hf.py index 9d8a0c747c..3c6d9425f7 100644 --- a/tests/a_scripts/inference/test_convert_composer_to_hf.py +++ b/tests/a_scripts/inference/test_convert_composer_to_hf.py @@ -428,6 +428,8 @@ def test_huggingface_conversion_callback_interval( 'hf_save_interval,save_interval,max_duration,expected_hf_checkpoints,expected_normal_checkpoints', [('1ba', '1ba', '1ba', 1, 1)]) @patch('os.cpu_count', MagicMock(return_value=1)) +@patch('llmfoundry.callbacks.hf_checkpointer.SpawnProcess', + new=MockSpawnProcess) def test_huggingface_conversion_callback( model: str, tmp_path: pathlib.Path, From 8b403ec0c51b4b805f3a181cd0ae1c59d26fe78e Mon Sep 17 00:00:00 2001 From: Irene Dea Date: Sun, 31 Mar 2024 18:54:56 +0000 Subject: [PATCH 07/14] Fix directory removal and nccl timeout --- llmfoundry/callbacks/hf_checkpointer.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py index fd60349fd4..3b83a27cab 100644 --- a/llmfoundry/callbacks/hf_checkpointer.py +++ b/llmfoundry/callbacks/hf_checkpointer.py @@ -254,8 +254,6 @@ def _is_last_batch(self, state: State): return False def _all_child_processes_done(self) -> bool: - if len(self.child_processes) == 0: - return True not_done = any(process.is_alive() for process in self.child_processes) x = torch.tensor(1 if not_done else 0).to(device='cuda') dist.all_reduce(x, reduce_operation='MAX') @@ -279,8 +277,10 @@ def _save_checkpoint(self, state: State, logger: Logger): self.huggingface_folder_name_fstr), state.run_name, state.timestamp) + # Use a temporary directory if save_dir is remote. + use_temp_dir = self.remote_ud is not None temp_save_dir = tempfile.mkdtemp( - ) if self.remote_ud is not None else save_dir + ) if use_temp_dir else save_dir log.debug('Gathering state dict') from torch.distributed.fsdp import FullyShardedDataParallel as FSDP @@ -442,8 +442,10 @@ def _save_checkpoint(self, state: State, logger: Logger): self.child_processes.append(process) # Save the temporary directory to be cleaned up later. - self.temp_save_dir = temp_save_dir + if use_temp_dir: + self.temp_save_dir = temp_save_dir else: # Clean up the temporary directory if we don't need to register to mlflow. - shutil.rmtree(temp_save_dir) + if use_temp_dir: + shutil.rmtree(temp_save_dir) dist.barrier() From 9193461a92010628613d17c4b1964875a8990997 Mon Sep 17 00:00:00 2001 From: Irene Dea Date: Sun, 31 Mar 2024 18:57:44 +0000 Subject: [PATCH 08/14] code quality --- llmfoundry/callbacks/hf_checkpointer.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py index 3b83a27cab..b1ed73498e 100644 --- a/llmfoundry/callbacks/hf_checkpointer.py +++ b/llmfoundry/callbacks/hf_checkpointer.py @@ -277,10 +277,9 @@ def _save_checkpoint(self, state: State, logger: Logger): self.huggingface_folder_name_fstr), state.run_name, state.timestamp) - # Use a temporary directory if save_dir is remote. + # Use a temporary directory if save_dir is remote. use_temp_dir = self.remote_ud is not None - temp_save_dir = tempfile.mkdtemp( - ) if use_temp_dir else save_dir + temp_save_dir = tempfile.mkdtemp() if use_temp_dir else save_dir log.debug('Gathering state dict') from torch.distributed.fsdp import FullyShardedDataParallel as FSDP From 8e6af116b41c5ccdf6198702d7329e6584da7a44 Mon Sep 17 00:00:00 2001 From: Irene Dea Date: Tue, 2 Apr 2024 15:53:22 +0000 Subject: [PATCH 09/14] Only log if logging level is greater than 0 --- llmfoundry/callbacks/hf_checkpointer.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py index b1ed73498e..90189ddeeb 100644 --- a/llmfoundry/callbacks/hf_checkpointer.py +++ b/llmfoundry/callbacks/hf_checkpointer.py @@ -76,19 +76,21 @@ def _maybe_get_license_filename( def _register_model_with_run_id_multiprocess(mlflow_logger: MLFlowLogger, - logging_level: int, model_uri: str, - name: str, + composer_logging_level: int, + model_uri: str, name: str, await_creation_for: int): """Function for calling MLFlowLogger.register_model_with_run_id from a. spawned child process. """ # Setup logging for child process. This ensures that any logs from composer are surfaced. - logging.basicConfig( - format= - f'%(asctime)s: rank{dist.get_global_rank()}[%(process)d][%(threadName)s]: %(levelname)s: %(name)s: %(message)s' - ) - logging.getLogger('composer').setLevel(logging_level) + if composer_logging_level > 0: + # If logging_level is 0, then the composer logger was unset. + logging.basicConfig( + format= + f'%(asctime)s: rank{dist.get_global_rank()}[%(process)d][%(threadName)s]: %(levelname)s: %(name)s: %(message)s' + ) + logging.getLogger('composer').setLevel(composer_logging_level) # Register model. mlflow_logger.register_model_with_run_id( From e63c4dadec1b9aebad8a9dac45e191c1b2c51e95 Mon Sep 17 00:00:00 2001 From: Irene Dea Date: Tue, 2 Apr 2024 15:58:53 +0000 Subject: [PATCH 10/14] Add timeout --- llmfoundry/callbacks/hf_checkpointer.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py index 90189ddeeb..2ff39a51fa 100644 --- a/llmfoundry/callbacks/hf_checkpointer.py +++ b/llmfoundry/callbacks/hf_checkpointer.py @@ -75,10 +75,13 @@ def _maybe_get_license_filename( return None -def _register_model_with_run_id_multiprocess(mlflow_logger: MLFlowLogger, - composer_logging_level: int, - model_uri: str, name: str, - await_creation_for: int): +def _register_model_with_run_id_multiprocess( + mlflow_logger: MLFlowLogger, + composer_logging_level: int, + model_uri: str, + name: str, + await_creation_for: int, +): """Function for calling MLFlowLogger.register_model_with_run_id from a. spawned child process. @@ -233,7 +236,14 @@ def run_event(self, event: Event, state: State, logger: Logger) -> None: '5GB') elif event == Event.FIT_END: # Wait for all child processes spawned by the callback to finish. + timeout = 3600 + wait_start = time.time() while not self._all_child_processes_done(): + wait_time = time.time() - wait_start + if wait_time > timeout: + raise TimeoutError( + f'Waited {wait_time} seconds for child processes to complete. Exceed timeout of {timeout} seconds.' + ) time.sleep(2) # Clean up temporary save directory; all processes are done with it. From 216425c5488f61a95abea2446cced64203e33e1f Mon Sep 17 00:00:00 2001 From: Irene Dea Date: Tue, 2 Apr 2024 08:59:31 -0700 Subject: [PATCH 11/14] Apply comment suggestions from code review Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com> --- llmfoundry/callbacks/hf_checkpointer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py index 2ff39a51fa..c10944ca71 100644 --- a/llmfoundry/callbacks/hf_checkpointer.py +++ b/llmfoundry/callbacks/hf_checkpointer.py @@ -82,9 +82,9 @@ def _register_model_with_run_id_multiprocess( name: str, await_creation_for: int, ): - """Function for calling MLFlowLogger.register_model_with_run_id from a. + """Call MLFlowLogger.register_model_with_run_id. - spawned child process. + Used mainly to register from a child process. """ # Setup logging for child process. This ensures that any logs from composer are surfaced. if composer_logging_level > 0: From 187eb1b4931d06ef45409a69b870e3bae2dc94a2 Mon Sep 17 00:00:00 2001 From: Irene Dea Date: Tue, 2 Apr 2024 16:05:44 +0000 Subject: [PATCH 12/14] Fix typo --- llmfoundry/callbacks/hf_checkpointer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py index c10944ca71..c5fb160c07 100644 --- a/llmfoundry/callbacks/hf_checkpointer.py +++ b/llmfoundry/callbacks/hf_checkpointer.py @@ -82,7 +82,7 @@ def _register_model_with_run_id_multiprocess( name: str, await_creation_for: int, ): - """Call MLFlowLogger.register_model_with_run_id. + """Call MLFlowLogger.register_model_with_run_id Used mainly to register from a child process. """ @@ -440,7 +440,7 @@ def _save_checkpoint(self, state: State, logger: Logger): kwargs={ 'mlflow_logger': mlflow_logger, - 'logging_level': + 'composer_logging_level': logging.getLogger('composer').level, 'model_uri': local_save_path, From d4d610732247d5eefebc581aeea87c60ab708d18 Mon Sep 17 00:00:00 2001 From: Irene Dea Date: Tue, 2 Apr 2024 17:16:00 +0000 Subject: [PATCH 13/14] code quality --- llmfoundry/callbacks/hf_checkpointer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py index c5fb160c07..bee2ff7f64 100644 --- a/llmfoundry/callbacks/hf_checkpointer.py +++ b/llmfoundry/callbacks/hf_checkpointer.py @@ -82,7 +82,7 @@ def _register_model_with_run_id_multiprocess( name: str, await_creation_for: int, ): - """Call MLFlowLogger.register_model_with_run_id + """Call MLFlowLogger.register_model_with_run_id. Used mainly to register from a child process. """ From fbdbadbc76e7785bc3d804247d26bcaf87ab953b Mon Sep 17 00:00:00 2001 From: Irene Dea Date: Tue, 2 Apr 2024 14:49:07 -0700 Subject: [PATCH 14/14] Fix typo Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com> --- llmfoundry/callbacks/hf_checkpointer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py index bee2ff7f64..0d57feaef4 100644 --- a/llmfoundry/callbacks/hf_checkpointer.py +++ b/llmfoundry/callbacks/hf_checkpointer.py @@ -242,7 +242,7 @@ def run_event(self, event: Event, state: State, logger: Logger) -> None: wait_time = time.time() - wait_start if wait_time > timeout: raise TimeoutError( - f'Waited {wait_time} seconds for child processes to complete. Exceed timeout of {timeout} seconds.' + f'Waited {wait_time} seconds for child processes to complete. Exceeded timeout of {timeout} seconds.' ) time.sleep(2)