Skip to content

Commit

Permalink
Add HuggingFaceCheckpointer option for only registering final checkpo…
Browse files Browse the repository at this point in the history
…int (#1516)
  • Loading branch information
irenedea authored Sep 12, 2024
1 parent dab768f commit a862d6e
Show file tree
Hide file tree
Showing 3 changed files with 257 additions and 64 deletions.
160 changes: 114 additions & 46 deletions llmfoundry/callbacks/hf_checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import shutil
import tempfile
import time
import warnings
from multiprocessing.context import SpawnProcess
from pathlib import Path
from typing import Any, Optional, Sequence, Union
Expand All @@ -18,6 +19,7 @@
import torch
import torch.nn as nn
from composer.core import Callback, Event, Precision, State, Time, TimeUnit
from composer.devices import Device
from composer.loggers import Logger, MLFlowLogger
from composer.models import HuggingFaceModel
from composer.utils import (
Expand Down Expand Up @@ -161,6 +163,10 @@ class HuggingFaceCheckpointer(Callback):
keys ``input_example`` and ``signature``.
flatten_imports (Sequence[str]): A sequence of import prefixes that will
be flattened when editing MPT files.
final_register_only (bool): If true, only register the model in the MLFlow
registry on the last batch and do not save the HuggingFace checkpoint. If
registration fails or mlflow_registered_model_name is not set, then we will
fallback to saving the HuggingFace checkpoint.
"""

def __init__(
Expand All @@ -173,6 +179,7 @@ def __init__(
mlflow_registered_model_name: Optional[str] = None,
mlflow_logging_config: Optional[dict] = None,
flatten_imports: Sequence[str] = ('llmfoundry',),
final_register_only: bool = False,
):
_, _, self.save_dir_format_str = parse_uri(save_folder)
self.overwrite = overwrite
Expand All @@ -185,8 +192,18 @@ def __init__(
self.flatten_imports = flatten_imports
self.using_peft = False

# mlflow config setup
self.final_register_only = final_register_only

self.mlflow_registered_model_name = mlflow_registered_model_name
if self.final_register_only and self.mlflow_registered_model_name is None:
self.final_register_only = False
warnings.warn(
'final_register_only is set to True, but mlflow_registered_model_name is not set. '
+
f'Defaulting to final_register_only=False and saving the HuggingFace checkpoint to {save_folder=}.',
)

# mlflow config setup
if mlflow_logging_config is None:
mlflow_logging_config = {}
if self.mlflow_registered_model_name is not None:
Expand Down Expand Up @@ -249,7 +266,7 @@ def __init__(
self.last_checkpoint_batch: Optional[Time] = None
self.mlflow_loggers = []

self.child_processes: list[SpawnProcess] = []
self.register_processes: list[SpawnProcess] = []
# Temporary save directory used by child_processes.
self.temp_save_dir = None

Expand All @@ -259,7 +276,17 @@ def run_event(self, event: Event, state: State, logger: Logger) -> None:
state,
event,
) and self.last_checkpoint_batch != state.timestamp.batch:
self._save_checkpoint(state, logger)
is_last_batch = self._is_last_batch(state)
self._save_checkpoint(
state,
logger,
register_to_mlflow=(
self.mlflow_registered_model_name is not None and
is_last_batch
),
upload_to_save_folder=not self.final_register_only or
not is_last_batch,
)
elif event == Event.INIT:
if not isinstance(state.model, HuggingFaceModel):
raise ValueError(
Expand Down Expand Up @@ -300,14 +327,27 @@ def run_event(self, event: Event, state: State, logger: Logger) -> None:
# Wait for all child processes spawned by the callback to finish.
timeout = 3600
wait_start = time.time()
while not self._all_child_processes_done():
while not self._all_register_processes_done(state.device):
wait_time = time.time() - wait_start
if wait_time > timeout:
raise TimeoutError(
f'Waited {wait_time} seconds for child processes to complete. Exceeded timeout of {timeout} seconds.',
)
time.sleep(2)

if self._any_register_processes_error(
state.device,
) and self.final_register_only:
log.error(
'An error occurred in one or more registration processes. Fallback to saving the HuggingFace checkpoint.',
)
self._save_checkpoint(
state,
logger,
upload_to_save_folder=True,
register_to_mlflow=False,
)

# Clean up temporary save directory; all processes are done with it.
if self.temp_save_dir is not None:
shutil.rmtree(self.temp_save_dir)
Expand Down Expand Up @@ -339,12 +379,23 @@ def _is_last_batch(self, state: State):

return False

def _all_child_processes_done(self) -> bool:
not_done = any(process.is_alive() for process in self.child_processes)
x = torch.tensor(1 if not_done else 0).to(device='cuda')
def _all_register_processes_done(self, device: Device) -> bool:
not_done = any(
process.is_alive() for process in self.register_processes
)
x = device.tensor_to_device(torch.tensor(1 if not_done else 0))
dist.all_reduce(x, reduce_operation='MAX')
return x.item() == 0

def _any_register_processes_error(self, device: Device) -> bool:
has_errors = any(
process.exitcode is not None and process.exitcode != 0
for process in self.register_processes
)
x = device.tensor_to_device(torch.tensor(1 if has_errors else 0))
dist.all_reduce(x, reduce_operation='MAX')
return x.item() == 1

def transform_model_and_tokenizer(
self,
model: PreTrainedModel,
Expand Down Expand Up @@ -412,7 +463,21 @@ def transform_model_pre_registration(
"""
return model

def _save_checkpoint(self, state: State, logger: Logger):
def _save_checkpoint(
self,
state: State,
logger: Logger,
upload_to_save_folder: bool,
register_to_mlflow: bool,
):
"""Save a HuggingFace formatted checkpoint.
Args:
state (State): The training state.
logger (Logger): The logger.
upload_to_save_folder (bool): Whether to upload the HF checkpoint to the save folder.
register_to_mlflow (bool): Whether to register the model to MLFlow
"""
del logger # unused

self.last_checkpoint_batch = state.timestamp.batch
Expand Down Expand Up @@ -548,50 +613,53 @@ def tensor_hook(
].base_model_name_or_path = self.pretrained_model_name

log.debug('Saving Hugging Face checkpoint to disk')
# This context manager casts the TE extra state in io.BytesIO format to tensor format
# Needed for proper hf ckpt saving.
context_manager = te.onnx_export(
True,
) if is_te_imported and state.precision == Precision.AMP_FP8 else contextlib.nullcontext(
)
with context_manager:
new_model_instance.save_pretrained(temp_save_dir)
if original_tokenizer is not None:
assert isinstance(
original_tokenizer,
PreTrainedTokenizerBase,
)
original_tokenizer.save_pretrained(temp_save_dir)

# Only need to edit files for MPT because it has custom code
if new_model_instance.config.model_type == 'mpt':
log.debug('Editing MPT files for HuggingFace compatibility')
edit_files_for_hf_compatibility(
temp_save_dir,
self.flatten_imports,
)

if self.remote_ud is not None:
for filename in os.listdir(temp_save_dir):
remote_file_name = os.path.join(save_dir, filename)
remote_file_uri = self.remote_ud.remote_backend.get_uri(
remote_file_name,
)
log.info(
f'Uploading HuggingFace formatted checkpoint to {remote_file_uri}',
if upload_to_save_folder:
# This context manager casts the TE extra state in io.BytesIO format to tensor format
# Needed for proper hf ckpt saving.
context_manager = te.onnx_export(
True,
) if is_te_imported and state.precision == Precision.AMP_FP8 else contextlib.nullcontext(
)
with context_manager:
new_model_instance.save_pretrained(temp_save_dir)
if original_tokenizer is not None:
assert isinstance(
original_tokenizer,
PreTrainedTokenizerBase,
)
self.remote_ud.upload_file(
state=state,
remote_file_name=remote_file_name,
file_path=Path(os.path.join(temp_save_dir, filename)),
overwrite=self.overwrite,
original_tokenizer.save_pretrained(temp_save_dir)

# Only need to edit files for MPT because it has custom code
if new_model_instance.config.model_type == 'mpt':
log.debug('Editing MPT files for HuggingFace compatibility')
edit_files_for_hf_compatibility(
temp_save_dir,
self.flatten_imports,
)

if self.remote_ud is not None:
for filename in os.listdir(temp_save_dir):
remote_file_name = os.path.join(save_dir, filename)
remote_file_uri = self.remote_ud.remote_backend.get_uri(
remote_file_name,
)
log.info(
f'Uploading HuggingFace formatted checkpoint to {remote_file_uri}',
)
self.remote_ud.upload_file(
state=state,
remote_file_name=remote_file_name,
file_path=Path(
os.path.join(temp_save_dir, filename),
),
overwrite=self.overwrite,
)

dist.barrier()

if dist.get_global_rank() == 0:
if self.mlflow_registered_model_name and self._is_last_batch(state):

if register_to_mlflow:
new_model_instance = self.transform_model_pre_registration(
new_model_instance,
)
Expand Down Expand Up @@ -680,7 +748,7 @@ def tensor_hook(
# Restore the monitor process.
if monitor_process is not None:
mlflow_logger.monitor_process = monitor_process # type: ignore
self.child_processes.append(process)
self.register_processes.append(process)

# Save the temporary directory to be cleaned up later.
if use_temp_dir:
Expand Down
7 changes: 6 additions & 1 deletion llmfoundry/command_utils/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -584,7 +584,12 @@ def train(cfg: DictConfig) -> Trainer:
)

hf_checkpointer_callback = hf_checkpointer_callbacks[0]
hf_checkpointer_callback._save_checkpoint(trainer.state, trainer.logger)
hf_checkpointer_callback._save_checkpoint(
trainer.state,
trainer.logger,
upload_to_save_folder=True,
register_to_mlflow=True,
)
return trainer

if train_cfg.only_composer_checkpoint:
Expand Down
Loading

0 comments on commit a862d6e

Please sign in to comment.