Merge branch 'main' into tp
eitanturok authored Sep 25, 2024
2 parents 3372ec0 + 722526d commit c09223c
Showing 26 changed files with 801 additions and 107 deletions.
2 changes: 1 addition & 1 deletion llmfoundry/_version.py
@@ -3,4 +3,4 @@

"""The LLM Foundry Version."""

__version__ = '0.12.0.dev0'
__version__ = '0.13.0.dev0'
168 changes: 119 additions & 49 deletions llmfoundry/callbacks/hf_checkpointer.py
@@ -10,6 +10,7 @@
import shutil
import tempfile
import time
import warnings
from multiprocessing.context import SpawnProcess
from pathlib import Path
from typing import Any, Optional, Sequence, Union
@@ -18,6 +19,7 @@
import torch
import torch.nn as nn
from composer.core import Callback, Event, Precision, State, Time, TimeUnit
from composer.devices import Device
from composer.loggers import Logger, MLFlowLogger
from composer.models import HuggingFaceModel
from composer.utils import (
@@ -161,6 +163,10 @@ class HuggingFaceCheckpointer(Callback):
keys ``input_example`` and ``signature``.
flatten_imports (Sequence[str]): A sequence of import prefixes that will
be flattened when editing MPT files.
final_register_only (bool): If True, only register the model in the MLFlow
registry on the last batch and do not save the HuggingFace checkpoint. If
registration fails or mlflow_registered_model_name is not set, we will
fall back to saving the HuggingFace checkpoint.
"""

def __init__(
@@ -173,6 +179,7 @@ def __init__(
mlflow_registered_model_name: Optional[str] = None,
mlflow_logging_config: Optional[dict] = None,
flatten_imports: Sequence[str] = ('llmfoundry',),
final_register_only: bool = False,
):
_, _, self.save_dir_format_str = parse_uri(save_folder)
self.overwrite = overwrite
@@ -185,8 +192,18 @@ def __init__(
self.flatten_imports = flatten_imports
self.using_peft = False

# mlflow config setup
self.final_register_only = final_register_only

self.mlflow_registered_model_name = mlflow_registered_model_name
if self.final_register_only and self.mlflow_registered_model_name is None:
self.final_register_only = False
warnings.warn(
'final_register_only is set to True, but mlflow_registered_model_name is not set. '
+
f'Defaulting to final_register_only=False and saving the HuggingFace checkpoint to {save_folder=}.',
)

# mlflow config setup
if mlflow_logging_config is None:
mlflow_logging_config = {}
if self.mlflow_registered_model_name is not None:
@@ -249,7 +266,7 @@ def __init__(
self.last_checkpoint_batch: Optional[Time] = None
self.mlflow_loggers = []

self.child_processes: list[SpawnProcess] = []
self.register_processes: list[SpawnProcess] = []
# Temporary save directory used by register_processes.
self.temp_save_dir = None

@@ -259,7 +276,17 @@ def run_event(self, event: Event, state: State, logger: Logger) -> None:
state,
event,
) and self.last_checkpoint_batch != state.timestamp.batch:
self._save_checkpoint(state, logger)
is_last_batch = self._is_last_batch(state)
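# On the final batch, register to MLFlow if a registered model name is set;
# with final_register_only, the HF checkpoint upload is skipped on that batch.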
self._save_checkpoint(
state,
logger,
register_to_mlflow=(
self.mlflow_registered_model_name is not None and
is_last_batch
),
upload_to_save_folder=not self.final_register_only or
not is_last_batch,
)
elif event == Event.INIT:
if not isinstance(state.model, HuggingFaceModel):
raise ValueError(
@@ -300,14 +327,27 @@ def run_event(self, event: Event, state: State, logger: Logger) -> None:
# Wait for all registration processes spawned by the callback to finish.
timeout = 3600
wait_start = time.time()
while not self._all_child_processes_done():
while not self._all_register_processes_done(state.device):
wait_time = time.time() - wait_start
if wait_time > timeout:
raise TimeoutError(
f'Waited {wait_time} seconds for child processes to complete. Exceeded timeout of {timeout} seconds.',
)
time.sleep(2)

if self._any_register_processes_error(
state.device,
) and self.final_register_only:
log.error(
'An error occurred in one or more registration processes. Falling back to saving the HuggingFace checkpoint.',
)
self._save_checkpoint(
state,
logger,
upload_to_save_folder=True,
register_to_mlflow=False,
)

# Clean up temporary save directory; all processes are done with it.
if self.temp_save_dir is not None:
shutil.rmtree(self.temp_save_dir)
@@ -339,12 +379,23 @@ def _is_last_batch(self, state: State):

return False

def _all_child_processes_done(self) -> bool:
not_done = any(process.is_alive() for process in self.child_processes)
x = torch.tensor(1 if not_done else 0).to(device='cuda')
def _all_register_processes_done(self, device: Device) -> bool:
not_done = any(
process.is_alive() for process in self.register_processes
)
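# MAX all-reduce across ranks: keep waiting while any rank still has a live registration process.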
x = device.tensor_to_device(torch.tensor(1 if not_done else 0))
dist.all_reduce(x, reduce_operation='MAX')
return x.item() == 0

def _any_register_processes_error(self, device: Device) -> bool:
has_errors = any(
process.exitcode is not None and process.exitcode != 0
for process in self.register_processes
)
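# MAX all-reduce across ranks: a nonzero exitcode on any rank marks the registration as failed everywhere.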
x = device.tensor_to_device(torch.tensor(1 if has_errors else 0))
dist.all_reduce(x, reduce_operation='MAX')
return x.item() == 1

def transform_model_and_tokenizer(
self,
model: PreTrainedModel,
@@ -412,7 +463,21 @@ def transform_model_pre_registration(
"""
return model

def _save_checkpoint(self, state: State, logger: Logger):
def _save_checkpoint(
self,
state: State,
logger: Logger,
upload_to_save_folder: bool,
register_to_mlflow: bool,
):
"""Save a HuggingFace formatted checkpoint.
Args:
state (State): The training state.
logger (Logger): The logger.
upload_to_save_folder (bool): Whether to upload the HF checkpoint to the save folder.
register_to_mlflow (bool): Whether to register the model to MLFlow
"""
del logger # unused

self.last_checkpoint_batch = state.timestamp.batch
Expand Down Expand Up @@ -520,11 +585,13 @@ def tensor_hook(
new_base_model_instance,
original_model.peft_config[active_adapter],
)
del new_base_model_instance
else:
new_model_instance = type(original_model)(new_config)
new_model_instance.generation_config.update(
**original_model.generation_config.to_dict(),
)
if new_model_instance.generation_config is not None:
new_model_instance.generation_config.update(
**original_model.generation_config.to_dict(),
)

# Then load the state dict in with "assign" so that the state dict
# is loaded properly even though the model is initially on meta device.
@@ -548,50 +615,53 @@ def tensor_hook(
].base_model_name_or_path = self.pretrained_model_name

log.debug('Saving Hugging Face checkpoint to disk')
# This context manager casts the TE extra state in io.BytesIO format to tensor format
# Needed for proper hf ckpt saving.
context_manager = te.onnx_export(
True,
) if is_te_imported and state.precision == Precision.AMP_FP8 else contextlib.nullcontext(
)
with context_manager:
new_model_instance.save_pretrained(temp_save_dir)
if original_tokenizer is not None:
assert isinstance(
original_tokenizer,
PreTrainedTokenizerBase,
)
original_tokenizer.save_pretrained(temp_save_dir)

# Only need to edit files for MPT because it has custom code
if new_model_instance.config.model_type == 'mpt':
log.debug('Editing MPT files for HuggingFace compatibility')
edit_files_for_hf_compatibility(
temp_save_dir,
self.flatten_imports,
)

if self.remote_ud is not None:
for filename in os.listdir(temp_save_dir):
remote_file_name = os.path.join(save_dir, filename)
remote_file_uri = self.remote_ud.remote_backend.get_uri(
remote_file_name,
)
log.info(
f'Uploading HuggingFace formatted checkpoint to {remote_file_uri}',
if upload_to_save_folder:
# This context manager casts the TE extra state in io.BytesIO format to tensor format
# Needed for proper hf ckpt saving.
context_manager = te.onnx_export(
True,
) if is_te_imported and state.precision == Precision.AMP_FP8 else contextlib.nullcontext(
)
with context_manager:
new_model_instance.save_pretrained(temp_save_dir)
if original_tokenizer is not None:
assert isinstance(
original_tokenizer,
PreTrainedTokenizerBase,
)
self.remote_ud.upload_file(
state=state,
remote_file_name=remote_file_name,
file_path=Path(os.path.join(temp_save_dir, filename)),
overwrite=self.overwrite,
original_tokenizer.save_pretrained(temp_save_dir)

# Only need to edit files for MPT because it has custom code
if new_model_instance.config.model_type == 'mpt':
log.debug('Editing MPT files for HuggingFace compatibility')
edit_files_for_hf_compatibility(
temp_save_dir,
self.flatten_imports,
)

if self.remote_ud is not None:
for filename in os.listdir(temp_save_dir):
remote_file_name = os.path.join(save_dir, filename)
remote_file_uri = self.remote_ud.remote_backend.get_uri(
remote_file_name,
)
log.info(
f'Uploading HuggingFace formatted checkpoint to {remote_file_uri}',
)
self.remote_ud.upload_file(
state=state,
remote_file_name=remote_file_name,
file_path=Path(
os.path.join(temp_save_dir, filename),
),
overwrite=self.overwrite,
)

dist.barrier()

if dist.get_global_rank() == 0:
if self.mlflow_registered_model_name and self._is_last_batch(state):

if register_to_mlflow:
new_model_instance = self.transform_model_pre_registration(
new_model_instance,
)
Expand Down Expand Up @@ -680,7 +750,7 @@ def tensor_hook(
# Restore the monitor process.
if monitor_process is not None:
mlflow_logger.monitor_process = monitor_process # type: ignore
self.child_processes.append(process)
self.register_processes.append(process)

# Save the temporary directory to be cleaned up later.
if use_temp_dir:
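
As an aside on the final_register_only option introduced above, here is a minimal usage sketch (not part of the diff). Argument values and the save_interval parameter are assumptions for illustration, not something this commit shows.

```python
# Hypothetical usage sketch of the new option; only final_register_only and
# mlflow_registered_model_name appear in this diff, the rest is assumed.
from llmfoundry.callbacks.hf_checkpointer import HuggingFaceCheckpointer

checkpointer = HuggingFaceCheckpointer(
    save_folder='s3://my-bucket/hf-checkpoints',  # assumed example path
    save_interval='1ep',                          # assumed interval format
    mlflow_registered_model_name='my-registered-model',
    final_register_only=True,  # register to MLFlow on the last batch only
)
# Per the diff, intermediate HF checkpoint uploads are skipped on the last
# batch; if registration fails, the callback falls back to saving the
# HuggingFace checkpoint to save_folder.
```
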
7 changes: 7 additions & 0 deletions llmfoundry/command_utils/data_prep/convert_delta_to_json.py
@@ -22,6 +22,7 @@
ClusterDoesNotExistError,
FailedToConnectToDatabricksError,
FailedToCreateSQLConnectionError,
InsufficientPermissionsError,
)

if TYPE_CHECKING:
@@ -454,6 +455,12 @@ def fetch(
sparkSession,
)
except Exception as e:
from pyspark.errors import AnalysisException
if isinstance(e, AnalysisException):
if 'INSUFFICIENT_PERMISSIONS' in e.message: # pyright: ignore
raise InsufficientPermissionsError(
action=f'reading from {tablename}',
) from e
raise RuntimeError(
f'Error in get rows from {tablename}. Restart sparkSession and try again',
) from e
19 changes: 13 additions & 6 deletions llmfoundry/command_utils/data_prep/convert_text_to_mds.py
@@ -32,6 +32,7 @@
CannotUnicodeDecodeFile,
DatasetTooSmallError,
InputFolderMissingDataError,
InputFolderNotFound,
OutputFolderNotEmptyError,
)

Expand Down Expand Up @@ -125,11 +126,15 @@ def get_object_names(input_folder: str) -> list[str]:
object_store = maybe_create_object_store_from_uri(input_folder)
if object_store is not None:
_, _, folder_prefix = parse_uri(input_folder)
names = [
name for name in object_store.list_objects(folder_prefix)
if name.endswith('.txt')
]
log.info(f'Found {len(names)} text files in remote storage')
try:
names = [
name for name in object_store.list_objects(folder_prefix)
if name.endswith('.txt')
]
log.info(f'Found {len(names)} text files in remote storage')
except FileNotFoundError:
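# A FileNotFoundError from list_objects means the remote input folder does
# not exist; surface it as the more specific InputFolderNotFound error.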
raise InputFolderNotFound(folder_prefix)

else:
# input_folder is a local folder
names = [
@@ -478,7 +483,9 @@ def convert_text_to_mds(
index_path = os.path.join(local_output_folder, 'index.json')
with open(index_path, 'r') as index_file:
if not json.load(index_file)['shards']:
raise DatasetTooSmallError()
raise DatasetTooSmallError(
reason='No shards were created when converting text to MDS.',
)

# Write a done file with the args and object names
write_done_file(local_output_folder, args_str, object_names)