Background mlflow register model
irenedea committed Mar 30, 2024
1 parent 7a8a156 commit 8594a80
Showing 1 changed file with 44 additions and 6 deletions.
50 changes: 44 additions & 6 deletions llmfoundry/callbacks/hf_checkpointer.py
@@ -8,8 +8,10 @@
 import os
 import re
 import tempfile
+import time
+from multiprocessing.context import SpawnProcess
 from pathlib import Path
-from typing import Any, Dict, Optional, Sequence, Union
+from typing import Any, Dict, List, Optional, Sequence, Union
 
 import torch
 from composer.core import Callback, Event, State, Time, TimeUnit
@@ -72,6 +74,19 @@ def _maybe_get_license_filename(
     return None
 
 
+def _register_model_with_run_id_multiprocess(mlflow_logger: MLFlowLogger,
+                                             logging_level: int, model_uri: str,
+                                             name: str,
+                                             await_creation_for: int):
+    logging.basicConfig(
+        format=
+        f'%(asctime)s: rank{dist.get_global_rank()}[%(process)d][%(threadName)s]: %(levelname)s: %(name)s: %(message)s'
+    )
+    logging.getLogger('composer').setLevel(logging_level)
+    mlflow_logger.register_model_with_run_id(
+        model_uri=model_uri, name=name, await_creation_for=await_creation_for)
+
+
 class HuggingFaceCheckpointer(Callback):
     """Save a huggingface formatted checkpoint during training.
@@ -170,6 +185,7 @@ def __init__(
 
         self.last_checkpoint_batch: Optional[Time] = None
         self.mlflow_loggers = []
+        self.child_processes: List[SpawnProcess] = []
 
     def run_event(self, event: Event, state: State, logger: Logger) -> None:
         # The interval scheduler handles only returning True for the appropriate events
@@ -202,6 +218,10 @@ def run_event(self, event: Event, state: State, logger: Logger) -> None:
                 import mlflow
                 mlflow.environment_variables.MLFLOW_HUGGINGFACE_MODEL_MAX_SHARD_SIZE.set(
                     '5GB')
+        elif event == Event.FIT_END:
+            # Wait for all child processes spawned by the callback to finish.
+            while not self._all_child_processes_done():
+                time.sleep(30)
 
     def _is_last_batch(self, state: State):
         elapsed_duration = state.get_elapsed_duration()
@@ -218,6 +238,12 @@ def _is_last_batch(self, state: State):
 
         return False
 
+    def _all_child_processes_done(self) -> bool:
+        not_done = any(process.is_alive() for process in self.child_processes)
+        x = torch.tensor(1 if not_done else 0).to(device='cuda')
+        dist.all_reduce(x, reduce_operation='MAX')
+        return x.item() == 0
+
     def _save_checkpoint(self, state: State, logger: Logger):
         del logger  # unused
 
@@ -385,8 +411,20 @@ def _save_checkpoint(self, state: State, logger: Logger):
                             os.path.join(local_save_path, license_filename),
                         )
 
-                    mlflow_logger.register_model_with_run_id(
-                        model_uri=local_save_path,
-                        name=self.mlflow_registered_model_name,
-                        await_creation_for=3600,
-                    )
+                    # Register the model to mlflow in a child process.
+                    process = SpawnProcess(
+                        target=_register_model_with_run_id_multiprocess,
+                        kwargs={
+                            'mlflow_logger':
+                                mlflow_logger,
+                            'logging_level':
+                                logging.getLogger('composer').level,
+                            'model_uri':
+                                local_save_path,
+                            'name':
+                                self.mlflow_registered_model_name,
+                            'await_creation_for':
+                                3600,
+                        })
+                    process.start()
+                    self.child_processes.append(process)
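
For reference, a minimal standalone sketch of the pattern this commit introduces: the slow MLflow registration call runs in a spawned child process so the training loop is not blocked, and the caller polls at fit end until every child process has exited. The register_model stand-in, the checkpoint path, and the model name below are illustrative placeholders, not part of llm-foundry; the real callback passes the MLFlowLogger into the child and calls register_model_with_run_id there, and it additionally all-reduces the "any process still alive" flag with MAX across ranks so that no rank leaves FIT_END before every rank's children have finished.

# Illustrative sketch only; function names and paths here are hypothetical.
import time
from multiprocessing import get_context


def register_model(model_uri: str, name: str) -> None:
    # Stand-in for MLFlowLogger.register_model_with_run_id, which can take minutes.
    print(f'registering {model_uri} as {name}')


def main() -> None:
    child_processes = []

    # Start registration in a 'spawn' child process so the caller keeps running.
    ctx = get_context('spawn')
    process = ctx.Process(
        target=register_model,
        kwargs={'model_uri': '/tmp/checkpoint', 'name': 'demo-model'},
    )
    process.start()
    child_processes.append(process)

    # At "fit end", block until every child process has finished.
    while any(p.is_alive() for p in child_processes):
        time.sleep(1)


if __name__ == '__main__':
    main()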
