Skip to content

Commit

Permalink
Merge branch 'main' into change_gauntlet_avging
Browse files Browse the repository at this point in the history
  • Loading branch information
bmosaicml authored Oct 24, 2023
2 parents 1fab1f2 + 091ddca commit 94b5438
Show file tree
Hide file tree
Showing 36 changed files with 2,269 additions and 560 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ You can select a specific commit hash such as `mosaicml/llm-foundry:1.13.1_cu117
|-------------------------------------------------------------|----------------|--------------|-------------------------------------|
| `mosaicml/pytorch:1.13.1_cu117-python3.10-ubuntu20.04` | 1.13.1 | 11.7 | No |
| `mosaicml/pytorch:2.0.1_cu118-python3.10-ubuntu20.04` | 2.0.1 | 11.8 | No |
| `mosaicml/pytorch:2.0.1_cu121-python3.10-ubuntu20.04` | 2.1.0 | 12.1 | No |
| `mosaicml/pytorch:2.1.0_cu121-python3.10-ubuntu20.04` | 2.1.0 | 12.1 | No |
| `mosaicml/llm-foundry:1.13.1_cu117-latest` | 1.13.1 | 11.7 | Yes |
| `mosaicml/llm-foundry:2.0.1_cu118-latest` | 2.0.1 | 11.8 | Yes |
| `mosaicml/llm-foundry:2.1.0_cu121-latest` | 2.1.0 | 12.1 | Yes (flash attention v1) |
Expand Down
121 changes: 16 additions & 105 deletions llmfoundry/callbacks/generate_callback.py
Original file line number Diff line number Diff line change
@@ -1,119 +1,30 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

"""Periodically log generations to wandb from a set of prompts."""
from typing import Any, List, Union, cast
"""Deprecated Generate callback.
import torch
import wandb
from composer.core import Callback, State, get_precision_context
from composer.loggers import Logger, WandBLogger
from composer.utils import dist, ensure_tuple
Please use composer.callbacks.Generate instead.
"""
import warnings
from typing import Any, List, Union

from composer.callbacks import Generate as ComposerGenerate
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast

Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]


class Generate(Callback):
class Generate(ComposerGenerate):

def __init__(self, prompts: List[str], batch_log_interval: int,
**kwargs: Any):
"""Periodically log generations to wandb from a set of prompts.
In the main view for a run, there will be a table that will show the _last_ logged generations.
To compare previous iterations of the generations, you need to
1. Click on the run
2. Click on "artifacts" in the menu on the left side of the screen
3. Click on one of the artifacts called "predictions"
4. Click on the "files" tab
5. Click on "predictions.table.json"
6. On the left hand side, there are different versions of the table produced throughout training. Select one of these.
7. Now, when you hover over other versions, there will be a "compare" button, which will allow you to compare the currently
selected version to the version you add via compare.
Args:
prompts (List[str]): The list of prompts you would like to produce generations for
batch_log_interval (int): The interval (in batches) at which this callback runs
kwargs: All kwargs well be passed along to the call to generate. This is for things like `do_sample`, `top_p`, etc
"""
self.prompts = prompts
self.batch_log_interval = batch_log_interval
self.generate_kwargs = kwargs
self.wandb_logger = None

def init(self, state: State, logger: Logger):
if dist.get_global_rank() == 0:
for destination in ensure_tuple(logger.destinations):
if isinstance(destination, WandBLogger):
self.wandb_logger = destination

def batch_checkpoint(self, state: State, logger: Logger) -> None:
if (state.timestamp.batch.value % self.batch_log_interval) == 0:
self.generate(state, logger)

def generate(self, state: State, logger: Logger) -> None:
model = state.model
original_mode = model.training
model.eval()
tokenizer = cast(Tokenizer, state.model.tokenizer)
device = state.device

if not hasattr(model.model, 'generate'):
raise ValueError(
f'Cannot generate from model {model.model.__class__.__name__} because it does not have a `generate` method'
)

# stash the original original value of padding_side because generation requires left padding
original_padding_side = tokenizer.padding_side
tokenizer.padding_side = 'left'
if tokenizer.pad_token_id is None:
tokenizer.pad_token_id = tokenizer.eos_token_id
tokenized_input = tokenizer(self.prompts,
return_tensors='pt',
padding=True)

for k, v in tokenized_input.items():
tokenized_input[k] = device.tensor_to_device(v)

# dummy forward call needed for FSDP to work consistently
dummy_input = torch.tensor([[0]], dtype=torch.long)
dummy_input = device.tensor_to_device(dummy_input)
with get_precision_context(state.precision):
with torch.no_grad():
assert isinstance(model.model, torch.nn.Module)
_ = model.model(input_ids=dummy_input)

output_token_ids = model.model.generate( # type: ignore
input_ids=tokenized_input['input_ids'],
attention_mask=tokenized_input['attention_mask'],
synced_gpus=True,
**self.generate_kwargs,
)

if dist.get_global_rank() == 0:
if self.wandb_logger is not None:
assert wandb.run is not None, 'wandb should have started run'

artifact = wandb.Artifact('generate_samples_' +
str(wandb.run.id),
type='predictions')

rows = []
for i in range(len(self.prompts)):
prompt = self.prompts[i]
output_tokens = output_token_ids[i][
tokenized_input['input_ids'].shape[1]:]
output_text = tokenizer.decode(output_tokens,
skip_special_tokens=True)

rows.append([prompt, output_text])

text_table = wandb.Table(data=rows,
columns=['prompt', 'generation'])
artifact.add(text_table, 'predictions')
wandb.log_artifact(artifact)
wandb.log({'generations': text_table},
step=state.timestamp.batch.value)
warnings.warn(
('Accessing llmfoundry.callbacks.generate_callback.Generate '
'is deprecated and will be removed in a future release. '
'Please use composer.callbacks.Generate instead.'),
DeprecationWarning,
)

tokenizer.padding_side = original_padding_side
model.train(mode=original_mode)
interval = f'{batch_log_interval}ba'
super().__init__(prompts=prompts, interval=interval, **kwargs)
149 changes: 118 additions & 31 deletions llmfoundry/callbacks/hf_checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,22 @@
# SPDX-License-Identifier: Apache-2.0

import contextlib
import json
import copy
import logging
import os
import tempfile
from pathlib import Path
from typing import Optional, Union

import torch
from composer.callbacks.utils import create_interval_scheduler
from composer.core import Callback, Event, State, Time
from composer.core.state import fsdp_state_dict_type_context
from composer.loggers import Logger
from composer.loggers import Logger, MLFlowLogger
from composer.loggers.remote_uploader_downloader import RemoteUploaderDownloader
from composer.models import HuggingFaceModel
from composer.utils import dist, format_name_with_dist_and_time, parse_uri
from transformers import PreTrainedTokenizerBase
from composer.utils.misc import create_interval_scheduler
from transformers import PreTrainedModel, PreTrainedTokenizerBase

from llmfoundry.models.mpt import MPTConfig, MPTForCausalLM
from llmfoundry.utils.huggingface_hub_utils import \
Expand All @@ -39,6 +39,11 @@ class HuggingFaceCheckpointer(Callback):
huggingface_folder_name (str): Folder to save each checkpoint under (can be a format string). Default is ``ba{batch}``.
precision: The precision to save the model in. Default is ``float32``. Options are ``bfloat16``, ``float16``, or ``float32``.
overwrite (bool): Whether to overwrite previous checkpoints.
mlflow_registered_model_name (Optional[str]): The name to register the model under in the MLflow model registry. If ``None``, the model will not
be registered. Default is ``None``.
mlflow_logging_config (Optional[dict]): A dictionary of config arguments that will get passed along to the MLflow ``save_model`` call.
Expected to contain ``metadata`` and ``task`` keys. If either is unspecified, the defaults are ``'text-generation'`` and
``{'task': 'llm/v1/completions'}`` respectively.
"""

def __init__(
Expand All @@ -48,6 +53,8 @@ def __init__(
huggingface_folder_name: str = 'ba{batch}',
precision: str = 'float32',
overwrite: bool = False,
mlflow_registered_model_name: Optional[str] = None,
mlflow_logging_config: Optional[dict] = None,
):
self.backend, self.bucket_name, self.save_dir_format_str = parse_uri(
save_folder)
Expand All @@ -58,6 +65,22 @@ def __init__(
'float16': torch.float16,
'bfloat16': torch.bfloat16,
}[precision]

# mlflow config setup
self.mlflow_registered_model_name = mlflow_registered_model_name
if mlflow_logging_config is None:
mlflow_logging_config = {}
if self.mlflow_registered_model_name is not None:
# Both the metadata and the task are needed in order for mlflow
# and databricks optimized model serving to work
if 'metadata' not in mlflow_logging_config:
mlflow_logging_config['metadata'] = {
'task': 'llm/v1/completions'
}
if 'task' not in mlflow_logging_config:
mlflow_logging_config['task'] = 'text-generation'
self.mlflow_logging_config = mlflow_logging_config

self.huggingface_folder_name_fstr = os.path.join(
'huggingface', huggingface_folder_name)
self.check_interval = create_interval_scheduler(
Expand All @@ -71,6 +94,7 @@ def __init__(
self.remote_ud = None

self.last_checkpoint_batch: Optional[Time] = None
self.mlflow_loggers = []

def run_event(self, event: Event, state: State, logger: Logger) -> None:
# The interval scheduler handles only returning True for the appropriate events
Expand All @@ -87,6 +111,23 @@ def run_event(self, event: Event, state: State, logger: Logger) -> None:
self.remote_ud.init(state, logger)
state.callbacks.append(self.remote_ud)

if self.mlflow_registered_model_name is not None:
self.mlflow_loggers = [
logger_destination
for logger_destination in logger.destinations
if isinstance(logger_destination, MLFlowLogger)
]
if len(self.mlflow_loggers) == 0:
raise ValueError(
f'`mlflow_registered_model_name` was set, but no `MLFlowLogger` was found in the `logger.destinations` list. '
+
'Please add an `MLFlowLogger` or set `mlflow_registered_model_name` to `None`.'
)

import mlflow
mlflow.environment_variables.MLFLOW_HUGGINGFACE_MODEL_MAX_SHARD_SIZE.set(
'5GB')

def _save_checkpoint(self, state: State, logger: Logger):
del logger # unused

Expand All @@ -99,8 +140,6 @@ def _save_checkpoint(self, state: State, logger: Logger):
MPTConfig.register_for_auto_class()
MPTForCausalLM.register_for_auto_class('AutoModelForCausalLM')

assert isinstance(state.model, HuggingFaceModel)

save_dir = format_name_with_dist_and_time(
str(
Path(self.save_dir_format_str) /
Expand All @@ -114,44 +153,65 @@ def _save_checkpoint(self, state: State, logger: Logger):
assert isinstance(temp_save_dir,
str) # pyright doesn't know about enter_result

with fsdp_state_dict_type_context(state.model.model,
state_dict_type='full'):
state_dict = state.model.model.state_dict()
log.debug('Gathering state dict')
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

if state.is_model_ddp:
original_model: PreTrainedModel = state.model.module.model
state_dict_model = state.model.module.model
original_tokenizer = state.model.module.tokenizer
elif isinstance(state.model.model, FSDP):
original_model: PreTrainedModel = state.model.model.module
state_dict_model = state.model.model
original_tokenizer = state.model.tokenizer
else:
original_model: PreTrainedModel = state.model.model
state_dict_model = state.model.model
original_tokenizer = state.model.tokenizer

state_dict_context = fsdp_state_dict_type_context(
original_model, state_dict_type='full') if (
(not state.is_model_ddp) and isinstance(
state_dict_model, FSDP)) else contextlib.nullcontext()

with state_dict_context:
state_dict = state_dict_model.state_dict()

# convert the state dict to the requested precision
for k, v in state_dict.items():
if isinstance(v, torch.Tensor):
state_dict[k] = v.to(dtype=self.dtype)

if dist.get_global_rank() == 0:
# We raise above if the model is not a HuggingFaceModel, so this assert is safe
assert hasattr(state.model.model, 'save_pretrained')
state.model.model.save_pretrained(temp_save_dir,
state_dict=state_dict)

if state.model.tokenizer is not None:
assert isinstance(state.model.tokenizer,
log.debug('Saving Hugging Face checkpoint to disk')

copied_config = copy.deepcopy(original_model.config)
if copied_config.model_type == 'mpt':
copied_config.attn_config['attn_impl'] = 'torch'
copied_config.init_device = 'cpu'

# TODO: after torch 2.1, we can load a state dict into a meta model
# and skip the extra model init
log.debug(f'Creating new model instance')
new_model_instance = type(original_model)(copied_config)
new_model_instance.to(dtype=self.dtype)
new_model_instance.load_state_dict(state_dict)
del state_dict

log.debug('Saving Hugging Face checkpoint to disk')
new_model_instance.save_pretrained(temp_save_dir)
if original_tokenizer is not None:
assert isinstance(original_tokenizer,
PreTrainedTokenizerBase)
state.model.tokenizer.save_pretrained(temp_save_dir)
original_tokenizer.save_pretrained(temp_save_dir)

# Only need to edit files for MPT because it has custom code
if state.model.model.config.model_type == 'mpt':
if original_model.config.model_type == 'mpt':
log.debug('Editing MPT files for HuggingFace compatibility')
edit_files_for_hf_compatibility(temp_save_dir)

with open(os.path.join(temp_save_dir, 'config.json'), 'r') as f:
edited_config = json.load(f)

if state.model.model.config.model_type == 'mpt':
edited_config['attn_config']['attn_impl'] = 'torch'
edited_config['init_device'] = 'cpu'

edited_config['torch_dtype'] = self.precision
with open(os.path.join(temp_save_dir, 'config.json'), 'w') as f:
json.dump(edited_config, f, indent=4)

if self.upload_to_object_store:
assert self.remote_ud is not None
# TODO change to log after other pr
log.info(
f'Uploading HuggingFace formatted checkpoint to {self.backend}://{self.bucket_name}/{save_dir}'
)
Expand All @@ -164,4 +224,31 @@ def _save_checkpoint(self, state: State, logger: Logger):
overwrite=self.overwrite,
)

dist.barrier()
elapsed_duration = state.get_elapsed_duration()
if self.mlflow_registered_model_name is not None and elapsed_duration is not None and elapsed_duration >= 1.0:
components = {'model': new_model_instance}
if original_tokenizer is not None:
components['tokenizer'] = original_tokenizer

log.debug('Logging Hugging Face model to MLFlow')
for i, mlflow_logger in enumerate(self.mlflow_loggers):
log.debug(
f'Registering model to UC at {mlflow_logger.model_registry_prefix}.{self.mlflow_registered_model_name}'
)
local_save_path = str(
Path(temp_save_dir) / f'mlflow_save_{i}')

# TODO: Remove after mlflow fixes the bug that makes this necessary
import mlflow
mlflow.store._unity_catalog.registry.rest_store.get_feature_dependencies = lambda *args, **kwargs: ''
mlflow_logger.save_model(
flavor='transformers',
transformers_model=components,
path=local_save_path,
**self.mlflow_logging_config,
)
mlflow_logger.register_model(
model_uri=local_save_path,
name=self.mlflow_registered_model_name,
await_registration_for=3600,
)
Loading

0 comments on commit 94b5438

Please sign in to comment.