Timeout works
irenedea committed Mar 29, 2024
1 parent 3906f5a commit 8277e47
Showing 1 changed file with 38 additions and 24 deletions.
llmfoundry/callbacks/hf_checkpointer.py (38 additions, 24 deletions)
@@ -8,11 +8,10 @@
import os
import re
import tempfile
from pathlib import Path
import time
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence, Union

from mlflow import MlflowException
import torch
from composer.core import Callback, Event, State, Time, TimeUnit
from composer.core.state import fsdp_state_dict_type_context
@@ -22,6 +21,7 @@
maybe_create_remote_uploader_downloader_from_uri,
parse_uri)
from composer.utils.misc import create_interval_scheduler
from mlflow import MlflowException
from mlflow.transformers import _fetch_model_card, _write_license_information
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from transformers import PreTrainedModel, PreTrainedTokenizerBase
@@ -173,11 +173,11 @@ def _save_hf_checkpoint(save_path: str, model: PreTrainedModel,
)


def _register_mlflow_model(
mlflow_loggers: List[MLFlowLogger], logging_config: dict,
registered_model_name: str, model: PreTrainedModel,
tokenizer: Optional[PreTrainedTokenizerBase], local_save_path: str,
using_peft: bool):
def _register_mlflow_model(mlflow_loggers: List[MLFlowLogger],
logging_config: dict, registered_model_name: str,
model: PreTrainedModel,
tokenizer: Optional[PreTrainedTokenizerBase],
local_save_path: str, using_peft: bool):
components = {'model': model}
if tokenizer is not None:
components['tokenizer'] = tokenizer
@@ -203,7 +203,6 @@ def _register_mlflow_model(
model_saving_kwargs['transformers_model'] = components
model_saving_kwargs.update(logging_config)

print("HEY", model_saving_kwargs)
mlflow_logger.save_model(**model_saving_kwargs)

# Upload the license file generated by mlflow during the model saving.
@@ -401,6 +400,7 @@ def _save_checkpoint(self, state: State, logger: Logger):
state_dict = _get_state_dict(state, original_model)
using_peft = _using_peft(state)

new_model_instance = None
if dist.get_global_rank() == 0:
log.debug('Saving Hugging Face checkpoint in global rank 0')

@@ -411,7 +411,7 @@ def _save_checkpoint(self, state: State, logger: Logger):
log.debug('Saving Hugging Face checkpoint to disk')
_save_hf_checkpoint(temp_save_dir, new_model_instance,
original_tokenizer, self.flatten_imports)

if self.remote_ud is not None:
for filename in os.listdir(temp_save_dir):
remote_file_name = os.path.join(save_dir, filename)
@@ -430,19 +430,30 @@ def _save_checkpoint(self, state: State, logger: Logger):

dist.barrier()

if self.mlflow_registered_model_name and self._is_last_batch(
state):
# def get_last_updated_timestamp(logger: MLFlowLogger) -> Optional[int]:
# from mlflow.protos.databricks_pb2 import RESOURCE_ALREADY_EXISTS, ErrorCode
# try:
# return logger._mlflow_client.get_registered_model(f'{logger.model_registry_prefix}.{self.mlflow_registered_model_name}').last_updated_timestamp
# except MlflowException as e:
# print('error code', e.error_code)
# return None
# # TODO: Make timestamp helper on mlflow logger
# get_last_updated = lambda: [get_last_updated_timestamp(mlflow_logger) for mlflow_logger in self.mlflow_loggers]
# last_updated = get_last_updated()
if self.mlflow_registered_model_name and self._is_last_batch(state):
# TODO: Make a timestamp helper on the MLflow logger. This requires creating a client for each logger
def get_last_updated_timestamp(
logger: MLFlowLogger) -> Optional[int]:
from mlflow.protos.databricks_pb2 import (
RESOURCE_DOES_NOT_EXIST, ErrorCode)
try:
return logger._mlflow_client.get_registered_model(
f'{logger.model_registry_prefix}.{self.mlflow_registered_model_name}'
).last_updated_timestamp
except MlflowException as e:
if e.error_code == ErrorCode.Name(
RESOURCE_DOES_NOT_EXIST):
return None
raise e

get_last_updated = lambda: [
get_last_updated_timestamp(mlflow_logger)
for mlflow_logger in self.mlflow_loggers
]
last_updated = get_last_updated()

if dist.get_global_rank() == 0:
assert new_model_instance is not None
_register_mlflow_model(
mlflow_loggers=self.mlflow_loggers,
logging_config=self.mlflow_logging_config,
@@ -452,8 +463,11 @@ def _save_checkpoint(self, state: State, logger: Logger):
local_save_path=temp_save_dir,
using_peft=using_peft,
)
# else:
# while any([last_updated[i] == timestamp for i, timestamp in enumerate(get_last_updated())]):
# time.sleep(60)
else:
while any([
last_updated[i] == timestamp
for i, timestamp in enumerate(get_last_updated())
]):
time.sleep(60)

dist.barrier()
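
The substance of this commit is the polling logic added at the end of _save_checkpoint: before registration, every rank snapshots the last_updated_timestamp of the registered model for each MLflow logger; rank 0 then registers the model while the other ranks sleep in 60-second intervals until every timestamp has changed, so no rank passes the final barrier before registration has completed. Below is a minimal, self-contained sketch of that wait pattern, assuming a configured MLflow tracking server; the function names, the model name argument, and the poll_interval parameter are illustrative and not part of the commit.

import time
from typing import Optional

from mlflow import MlflowClient, MlflowException
from mlflow.protos.databricks_pb2 import RESOURCE_DOES_NOT_EXIST, ErrorCode


def get_last_updated_timestamp(client: MlflowClient,
                               name: str) -> Optional[int]:
    # Return the registered model's last update time, or None if the model
    # has not been registered yet.
    try:
        return client.get_registered_model(name).last_updated_timestamp
    except MlflowException as e:
        if e.error_code == ErrorCode.Name(RESOURCE_DOES_NOT_EXIST):
            return None
        raise


def wait_for_registration(client: MlflowClient,
                          name: str,
                          timestamp_before: Optional[int],
                          poll_interval: float = 60.0) -> None:
    # Poll until the registry reports a timestamp different from the snapshot
    # taken before registration started (None means the model did not exist).
    while get_last_updated_timestamp(client, name) == timestamp_before:
        time.sleep(poll_interval)

On the non-rank-0 workers this mirrors the callback's while/any loop, except the sketch tracks a single registry entry rather than one timestamp per MLflow logger.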
