Merge branch 'main' into chuck/add_profiler_flags
j316chuck authored Oct 17, 2023
2 parents 7a6dc9d + cc238a3 commit 967b4d2
Showing 9 changed files with 440 additions and 55 deletions.
149 changes: 118 additions & 31 deletions llmfoundry/callbacks/hf_checkpointer.py
@@ -2,22 +2,22 @@
# SPDX-License-Identifier: Apache-2.0

import contextlib
import json
import copy
import logging
import os
import tempfile
from pathlib import Path
from typing import Optional, Union

import torch
from composer.callbacks.utils import create_interval_scheduler
from composer.core import Callback, Event, State, Time
from composer.core.state import fsdp_state_dict_type_context
from composer.loggers import Logger
from composer.loggers import Logger, MLFlowLogger
from composer.loggers.remote_uploader_downloader import RemoteUploaderDownloader
from composer.models import HuggingFaceModel
from composer.utils import dist, format_name_with_dist_and_time, parse_uri
from transformers import PreTrainedTokenizerBase
from composer.utils.misc import create_interval_scheduler
from transformers import PreTrainedModel, PreTrainedTokenizerBase

from llmfoundry.models.mpt import MPTConfig, MPTForCausalLM
from llmfoundry.utils.huggingface_hub_utils import \
@@ -39,6 +39,11 @@ class HuggingFaceCheckpointer(Callback):
huggingface_folder_name (str): Folder to save each checkpoint under (can be a format string). Default is ``ba{batch}``.
precision: The precision to save the model in. Default is ``float32``. Options are ``bfloat16``, ``float16``, or ``float32``.
overwrite (bool): Whether to overwrite previous checkpoints.
mlflow_registered_model_name (Optional[str]): The name to register the model under in the MLflow model registry. If ``None``, the model will not
be registered. Default is ``None``.
mlflow_logging_config (Optional[dict]): A dictionary of config arguments that will get passed along to the MLflow ``save_model`` call.
Expected to contain ``metadata`` and ``task`` keys. If either is unspecified, the defaults are ``'text-generation'`` and
``{'task': 'llm/v1/completions'}`` respectively.
"""

def __init__(
@@ -48,6 +53,8 @@ def __init__(
huggingface_folder_name: str = 'ba{batch}',
precision: str = 'float32',
overwrite: bool = False,
mlflow_registered_model_name: Optional[str] = None,
mlflow_logging_config: Optional[dict] = None,
):
self.backend, self.bucket_name, self.save_dir_format_str = parse_uri(
save_folder)
@@ -58,6 +65,22 @@ def __init__(
'float16': torch.float16,
'bfloat16': torch.bfloat16,
}[precision]

# mlflow config setup
self.mlflow_registered_model_name = mlflow_registered_model_name
if mlflow_logging_config is None:
mlflow_logging_config = {}
if self.mlflow_registered_model_name is not None:
# Both the metadata and the task are needed in order for mlflow
# and databricks optimized model serving to work
if 'metadata' not in mlflow_logging_config:
mlflow_logging_config['metadata'] = {
'task': 'llm/v1/completions'
}
if 'task' not in mlflow_logging_config:
mlflow_logging_config['task'] = 'text-generation'
self.mlflow_logging_config = mlflow_logging_config

self.huggingface_folder_name_fstr = os.path.join(
'huggingface', huggingface_folder_name)
self.check_interval = create_interval_scheduler(
@@ -71,6 +94,7 @@ def __init__(
self.remote_ud = None

self.last_checkpoint_batch: Optional[Time] = None
self.mlflow_loggers = []

def run_event(self, event: Event, state: State, logger: Logger) -> None:
# The interval scheduler handles only returning True for the appropriate events
@@ -87,6 +111,23 @@ def run_event(self, event: Event, state: State, logger: Logger) -> None:
self.remote_ud.init(state, logger)
state.callbacks.append(self.remote_ud)

if self.mlflow_registered_model_name is not None:
self.mlflow_loggers = [
logger_destination
for logger_destination in logger.destinations
if isinstance(logger_destination, MLFlowLogger)
]
if len(self.mlflow_loggers) == 0:
raise ValueError(
f'`mlflow_registered_model_name` was set, but no `MLFlowLogger` was found in the `logger.destinations` list. '
+
'Please add an `MLFlowLogger` or set `mlflow_registered_model_name` to `None`.'
)

import mlflow
mlflow.environment_variables.MLFLOW_HUGGINGFACE_MODEL_MAX_SHARD_SIZE.set(
'5GB')

def _save_checkpoint(self, state: State, logger: Logger):
del logger # unused

@@ -99,8 +140,6 @@ def _save_checkpoint(self, state: State, logger: Logger):
MPTConfig.register_for_auto_class()
MPTForCausalLM.register_for_auto_class('AutoModelForCausalLM')

assert isinstance(state.model, HuggingFaceModel)

save_dir = format_name_with_dist_and_time(
str(
Path(self.save_dir_format_str) /
@@ -114,44 +153,65 @@ def _save_checkpoint(self, state: State, logger: Logger):
assert isinstance(temp_save_dir,
str) # pyright doesn't know about enter_result

with fsdp_state_dict_type_context(state.model.model,
state_dict_type='full'):
state_dict = state.model.model.state_dict()
log.debug('Gathering state dict')
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

if state.is_model_ddp:
original_model: PreTrainedModel = state.model.module.model
state_dict_model = state.model.module.model
original_tokenizer = state.model.module.tokenizer
elif isinstance(state.model.model, FSDP):
original_model: PreTrainedModel = state.model.model.module
state_dict_model = state.model.model
original_tokenizer = state.model.tokenizer
else:
original_model: PreTrainedModel = state.model.model
state_dict_model = state.model.model
original_tokenizer = state.model.tokenizer

state_dict_context = fsdp_state_dict_type_context(
original_model, state_dict_type='full') if (
(not state.is_model_ddp) and isinstance(
state_dict_model, FSDP)) else contextlib.nullcontext()

with state_dict_context:
state_dict = state_dict_model.state_dict()

# convert the state dict to the requested precision
for k, v in state_dict.items():
if isinstance(v, torch.Tensor):
state_dict[k] = v.to(dtype=self.dtype)

if dist.get_global_rank() == 0:
# We raise above if the model is not a HuggingFaceModel, so this assert is safe
assert hasattr(state.model.model, 'save_pretrained')
state.model.model.save_pretrained(temp_save_dir,
state_dict=state_dict)

if state.model.tokenizer is not None:
assert isinstance(state.model.tokenizer,
log.debug('Saving Hugging Face checkpoint to disk')

copied_config = copy.deepcopy(original_model.config)
if copied_config.model_type == 'mpt':
copied_config.attn_config['attn_impl'] = 'torch'
copied_config.init_device = 'cpu'

# TODO: after torch 2.1, we can load a state dict into a meta model
# and skip the extra model init
log.debug(f'Creating new model instance')
new_model_instance = type(original_model)(copied_config)
new_model_instance.to(dtype=self.dtype)
new_model_instance.load_state_dict(state_dict)
del state_dict

log.debug('Saving Hugging Face checkpoint to disk')
new_model_instance.save_pretrained(temp_save_dir)
if original_tokenizer is not None:
assert isinstance(original_tokenizer,
PreTrainedTokenizerBase)
state.model.tokenizer.save_pretrained(temp_save_dir)
original_tokenizer.save_pretrained(temp_save_dir)

# Only need to edit files for MPT because it has custom code
if state.model.model.config.model_type == 'mpt':
if original_model.config.model_type == 'mpt':
log.debug('Editing MPT files for HuggingFace compatibility')
edit_files_for_hf_compatibility(temp_save_dir)

with open(os.path.join(temp_save_dir, 'config.json'), 'r') as f:
edited_config = json.load(f)

if state.model.model.config.model_type == 'mpt':
edited_config['attn_config']['attn_impl'] = 'torch'
edited_config['init_device'] = 'cpu'

edited_config['torch_dtype'] = self.precision
with open(os.path.join(temp_save_dir, 'config.json'), 'w') as f:
json.dump(edited_config, f, indent=4)

if self.upload_to_object_store:
assert self.remote_ud is not None
# TODO change to log after other pr
log.info(
f'Uploading HuggingFace formatted checkpoint to {self.backend}://{self.bucket_name}/{save_dir}'
)
@@ -164,4 +224,31 @@ def _save_checkpoint(self, state: State, logger: Logger):
overwrite=self.overwrite,
)

dist.barrier()
elapsed_duration = state.get_elapsed_duration()
if self.mlflow_registered_model_name is not None and elapsed_duration is not None and elapsed_duration >= 1.0:
components = {'model': new_model_instance}
if original_tokenizer is not None:
components['tokenizer'] = original_tokenizer

log.debug('Logging Hugging Face model to MLFlow')
for i, mlflow_logger in enumerate(self.mlflow_loggers):
log.debug(
f'Registering model to UC at {mlflow_logger.model_registry_prefix}.{self.mlflow_registered_model_name}'
)
local_save_path = str(
Path(temp_save_dir) / f'mlflow_save_{i}')

# TODO: Remove after mlflow fixes the bug that makes this necessary
import mlflow
mlflow.store._unity_catalog.registry.rest_store.get_feature_dependencies = lambda *args, **kwargs: ''
mlflow_logger.save_model(
flavor='transformers',
transformers_model=components,
path=local_save_path,
**self.mlflow_logging_config,
)
mlflow_logger.register_model(
model_uri=local_save_path,
name=self.mlflow_registered_model_name,
await_registration_for=3600,
)
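
For reference, a minimal sketch of how the new MLflow options might be passed when constructing this callback. The bucket URI, model name, and the save_interval argument are assumptions not shown in this hunk; other arguments keep the defaults documented above.

from llmfoundry.callbacks.hf_checkpointer import HuggingFaceCheckpointer

hf_checkpointer = HuggingFaceCheckpointer(
    save_folder='s3://my-bucket/checkpoints',  # hypothetical object-store path
    save_interval='1ep',                       # assumed; not shown in this hunk
    precision='bfloat16',
    # Requires an MLFlowLogger among the trainer's logger destinations;
    # otherwise the callback raises a ValueError at INIT.
    mlflow_registered_model_name='my-mpt-model',
    # Optional: missing keys are filled in by the callback with
    # 'text-generation' and {'task': 'llm/v1/completions'} metadata.
    mlflow_logging_config={'task': 'text-generation'},
)

Registration itself only runs once the elapsed training duration reaches 1.0, i.e. for the final checkpoint.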
18 changes: 14 additions & 4 deletions llmfoundry/data/denoising.py
@@ -10,13 +10,15 @@

import numpy as np
import torch
from composer.core.data_spec import DataSpec
from omegaconf import DictConfig
from omegaconf import OmegaConf as om
from torch.utils.data import DataLoader
from transformers import PreTrainedTokenizerBase

from llmfoundry.data.packing import BinPackWrapper
from llmfoundry.data.text_data import StreamingTextDataset
from llmfoundry.data.text_data import (StreamingTextDataset,
get_tokens_per_batch_func)
from llmfoundry.models import utils

__all__ = ['MixtureOfDenoisersCollator', 'build_text_denoising_dataloader']
@@ -353,7 +355,7 @@ def build_text_denoising_dataloader(
cfg: DictConfig,
tokenizer: PreTrainedTokenizerBase,
device_batch_size: int,
) -> DataLoader[Dict]:
) -> DataSpec:
"""Constructor function for a Mixture of Denoisers dataloader.
This function constructs a dataloader that can be used to train an
@@ -506,7 +508,7 @@ def build_text_denoising_dataloader(
'but cfg.dataset.packing_ratio has not been set. Please set ' +\
'the latter to turn on packing or remove the former from the config.')

return DataLoader(
dl = DataLoader(
dataset,
collate_fn=collate_fn,
batch_size=device_batch_size,
Expand All @@ -518,6 +520,12 @@ def build_text_denoising_dataloader(
timeout=cfg.get('timeout', 0),
)

token_counting_func = get_tokens_per_batch_func(
pad_token_id=tokenizer.pad_token_id,
decoder_only=cfg.mixture_of_denoisers.decoder_only_format)

return DataSpec(dataloader=dl, get_num_tokens_in_batch=token_counting_func)


def noise_token_sequence(
example: Union[torch.Tensor, Mapping[str, Any]],
@@ -869,7 +877,9 @@ def _format_tokens_for_decoder_only(
tokenizer = build_tokenizer(tokenizer_name=tokenizer_name,
tokenizer_kwargs=tokenizer_kwargs)

loader = build_text_denoising_dataloader(cfg, tokenizer, device_batch_size)
loader = build_text_denoising_dataloader(cfg, tokenizer,
device_batch_size).dataloader
assert isinstance(loader, DataLoader)
assert isinstance(loader.dataset, StreamingTextDataset)

print(f'\n\nTRUNCATING TO: {loader.dataset.max_seq_len}\n\n')
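
The builder above now returns a composer DataSpec whose token-counting hook comes from get_tokens_per_batch_func in llmfoundry.data.text_data. As a rough, illustrative stand-in (not the code from this commit), a counter of that shape might look like the following; the real function also takes a decoder_only flag for encoder-decoder batches, as the call above shows.

import torch

def tokens_per_batch(pad_token_id: int):
    # Returns a callable usable as DataSpec.get_num_tokens_in_batch.
    def count(batch: dict) -> int:
        # Count the non-padding tokens in the batch's input_ids.
        return int((batch['input_ids'] != pad_token_id).sum())
    return count

counter = tokens_per_batch(pad_token_id=0)
batch = {'input_ids': torch.tensor([[5, 6, 0, 0], [7, 8, 9, 0]])}
print(counter(batch))  # 5 non-pad tokens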
16 changes: 12 additions & 4 deletions llmfoundry/data/finetuning/dataloader.py
@@ -6,6 +6,7 @@

import datasets as hf_datasets
import torch
from composer.core.data_spec import DataSpec
from composer.utils import dist, get_file, parse_uri
from omegaconf import DictConfig
from torch.utils.data import DataLoader
@@ -14,6 +15,7 @@
from llmfoundry.data.finetuning.collator import Seq2SeqFinetuningCollator
from llmfoundry.data.finetuning.tasks import dataset_constructor
from llmfoundry.data.packing import BinPackWrapper
from llmfoundry.data.text_data import get_tokens_per_batch_func

log = logging.getLogger(__name__)

@@ -23,7 +25,7 @@

def build_finetuning_dataloader(cfg: DictConfig,
tokenizer: PreTrainedTokenizerBase,
device_batch_size: int) -> DataLoader:
device_batch_size: int) -> DataSpec:
"""Builds a finetuning dataloader for training or evaluating.
The underlying dataset can be built through one of two code paths:
@@ -143,7 +145,7 @@ def build_finetuning_dataloader(cfg: DictConfig,
collate_fn, dataloader_batch_size = _build_collate_fn(
cfg.dataset, tokenizer, device_batch_size)

return DataLoader(
dl = DataLoader(
dataset,
collate_fn=collate_fn,
batch_size=dataloader_batch_size,
@@ -193,7 +195,7 @@ def build_finetuning_dataloader(cfg: DictConfig,
)

assert dataset is not None
return DataLoader(
dl = DataLoader(
dataset,
collate_fn=collate_fn,
batch_size=dataloader_batch_size,
@@ -208,6 +210,11 @@ def build_finetuning_dataloader(cfg: DictConfig,
timeout=cfg.get('timeout', 0),
)

token_counting_func = get_tokens_per_batch_func(
pad_token_id=tokenizer.pad_token_id)

return DataSpec(dataloader=dl, get_num_tokens_in_batch=token_counting_func)


def _validate_config(dataset_cfg: DictConfig) -> None:
"""Validates the dataset configuration.
@@ -442,7 +449,8 @@ def _build_collate_fn(
tokenizer = build_tokenizer(tokenizer_name, tokenizer_kwargs)

device_batch_size = 2
dataloader = build_finetuning_dataloader(cfg, tokenizer, device_batch_size)
dataloader = build_finetuning_dataloader(cfg, tokenizer,
device_batch_size).dataloader

packing = cfg.dataset.get('packing_ratio') is not None

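
Like the denoising builder, build_finetuning_dataloader now returns a composer DataSpec rather than a bare DataLoader, which is why the __main__ block above grabs .dataloader. A minimal, self-contained sketch of the API shape (toy data, not llm-foundry code):

import torch
from composer.core.data_spec import DataSpec
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(16).reshape(4, 4))
dl = DataLoader(dataset, batch_size=2)
spec = DataSpec(dataloader=dl,
                get_num_tokens_in_batch=lambda batch: int(batch[0].numel()))

batch = next(iter(spec.dataloader))          # DataLoader access is unchanged
print(spec.get_num_tokens_in_batch(batch))   # 8 elements in this toy batch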
2 changes: 1 addition & 1 deletion llmfoundry/data/packing.py
@@ -377,7 +377,7 @@ def build_dataloader(cfg: DictConfig, tokenizer: PreTrainedTokenizerBase,
dataloader_cfg.dataset.packing_ratio = None
dataloader_cfg.dataset.max_leftovers_to_keep = None
train_dataloader = build_dataloader(dataloader_cfg, tokenizer,
max(raw_batch_sizes) * 100)
max(raw_batch_sizes) * 100).dataloader

# Get a bunch of raw examples
big_batch = next(iter(train_dataloader))