Skip to content

Commit

Permalink
switch to providing registered name
Browse files Browse the repository at this point in the history
  • Loading branch information
dakinggg committed Oct 2, 2023
1 parent c7161f8 commit c715832
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 15 deletions.
29 changes: 15 additions & 14 deletions llmfoundry/callbacks/hf_checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ class HuggingFaceCheckpointer(Callback):
huggingface_folder_name (str): Folder to save each checkpoint under (can be a format string). Default is ``ba{batch}``.
precision: The precision to save the model in. Default is ``float32``. Options are ``bfloat16``, ``float16``, or ``float32``.
overwrite (bool): Whether to overwrite previous checkpoints.
log_to_mlflow (bool): Whether to register the model to MLflow. This will only register one model at the end of training. Default is ``False``.
mlflow_registered_model_name (Optional[str]): The name to register the model under in the MLflow model registry. If ``None``, the model will not
be registered. Only one model is registered, at the end of training. Default is ``None``.
mlflow_logging_config (Optional[dict]): A dictionary of config arguments that will get passed along to the MLflow ``save_model`` call.
Expected to contain ``metadata`` and ``task`` keys. If either is unspecified, the defaults are ``{'task': 'llm/v1/completions'}`` and
``'text-generation'`` respectively.
Expand All @@ -53,7 +54,7 @@ def __init__(
huggingface_folder_name: str = 'ba{batch}',
precision: str = 'float32',
overwrite: bool = False,
log_to_mlflow: bool = False,
mlflow_registered_model_name: Optional[str] = None,
mlflow_logging_config: Optional[dict] = None,
):
self.backend, self.bucket_name, self.save_dir_format_str = parse_uri(
Expand All @@ -67,13 +68,14 @@ def __init__(
}[precision]

# mlflow config setup
self.log_to_mlflow = log_to_mlflow
if mlflow_logging_config is None:
mlflow_logging_config = {}
if 'metadata' not in mlflow_logging_config:
mlflow_logging_config['metadata'] = {'task': 'llm/v1/completions'}
if 'task' not in mlflow_logging_config:
mlflow_logging_config['task'] = 'text-generation'
self.mlflow_registered_model_name = mlflow_registered_model_name
if self.mlflow_registered_model_name is not None:
if mlflow_logging_config is None:
mlflow_logging_config = {}
if 'metadata' not in mlflow_logging_config:
mlflow_logging_config['metadata'] = {'task': 'llm/v1/completions'}
if 'task' not in mlflow_logging_config:
mlflow_logging_config['task'] = 'text-generation'
self.mlflow_logging_config = mlflow_logging_config

self.huggingface_folder_name_fstr = os.path.join(
Expand Down Expand Up @@ -106,7 +108,7 @@ def run_event(self, event: Event, state: State, logger: Logger) -> None:
self.remote_ud.init(state, logger)
state.callbacks.append(self.remote_ud)

if self.log_to_mlflow:
if self.mlflow_registered_model_name is not None:
self.mlflow_loggers = [
logger_destination
for logger_destination in logger.destinations
Expand Down Expand Up @@ -216,16 +218,15 @@ def _save_checkpoint(self, state: State, logger: Logger):
)

elapsed_duration = state.get_elapsed_duration()
if self.log_to_mlflow and elapsed_duration is not None and elapsed_duration >= 1.0:
if self.mlflow_registered_model_name is not None and elapsed_duration is not None and elapsed_duration >= 1.0:
components = {'model': new_model_instance}
if original_tokenizer is not None:
components['tokenizer'] = original_tokenizer

log.debug('Logging Hugging Face model to MLFlow')
registered_model_name = f'{state.run_name}_{os.path.basename(save_dir)}'
for i, mlflow_logger in enumerate(self.mlflow_loggers):
log.debug(
f'Registering model to UC at {mlflow_logger.model_registry_prefix}.{registered_model_name}'
f'Registering model to UC at {mlflow_logger.model_registry_prefix}.{self.mlflow_registered_model_name}'
)
local_save_path = str(
Path(temp_save_dir) / f'mlflow_save_{i}')
Expand All @@ -241,6 +242,6 @@ def _save_checkpoint(self, state: State, logger: Logger):
)
mlflow_logger.register_model(
model_uri=local_save_path,
name=registered_model_name,
name=self.mlflow_registered_model_name,
await_registration_for=3600,
)
2 changes: 1 addition & 1 deletion tests/test_hf_conversion_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ def test_huggingface_conversion_callback(model: str, tmp_path: pathlib.Path,
save_folder=os.path.join(tmp_path, 'checkpoints'),
save_interval=f'{huggingface_save_interval_batches}ba',
precision=precision_str,
log_to_mlflow=log_to_mlflow,
mlflow_registered_model_name='dummy-registered-name',
)

# get small version of each model
Expand Down

0 comments on commit c715832

Please sign in to comment.