def _register_model_with_run_id_multiprocess(mlflow_logger: MLFlowLogger,
                                             logging_level: int, model_uri: str,
                                             name: str,
                                             await_creation_for: int):
    """Configure logging, then register a model via ``mlflow_logger``.

    Written to be used as the target of a separate process: it sets up its own
    logging configuration before delegating to
    ``mlflow_logger.register_model_with_run_id``.

    Args:
        mlflow_logger (MLFlowLogger): Logger whose
            ``register_model_with_run_id`` method performs the registration.
        logging_level (int): Level applied to the ``'composer'`` logger.
        model_uri (str): Forwarded unchanged as ``model_uri``.
        name (str): Forwarded unchanged as ``name``.
        await_creation_for (int): Forwarded unchanged as
            ``await_creation_for``.
    """
    # The record format embeds the global rank so that output from different
    # ranks can be told apart.
    log_format = f'%(asctime)s: rank{dist.get_global_rank()}[%(process)d][%(threadName)s]: %(levelname)s: %(name)s: %(message)s'
    logging.basicConfig(format=log_format)

    composer_logger = logging.getLogger('composer')
    composer_logger.setLevel(logging_level)

    registration_kwargs = {
        'model_uri': model_uri,
        'name': name,
        'await_creation_for': await_creation_for,
    }
    mlflow_logger.register_model_with_run_id(**registration_kwargs)
def _all_child_processes_done(self) -> bool:
    """Report whether every rank's child processes have finished.

    Each rank inspects its own ``self.child_processes`` and encodes the result
    as a 0/1 flag. The flags are combined with a MAX all-reduce, so a single
    rank with a live child makes the combined flag non-zero.

    Returns:
        bool: True when the reduced flag is 0, i.e. no rank reported a child
            process that is still alive.
    """
    not_done = any(process.is_alive() for process in self.child_processes)

    # Only place the flag on the GPU when one exists; a hard-coded 'cuda'
    # device raises on CPU-only hosts.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    x = torch.tensor(1 if not_done else 0, device=device)

    # Deliberately no early return when this rank has no children: the
    # reduction is issued unconditionally so that all ranks take part in it.
    dist.all_reduce(x, reduce_operation='MAX')

    # NOTE(review): a child that exited with a non-zero exit code is no longer
    # alive and therefore counts as done here; failures are not surfaced.
    return x.item() == 0