
Commit

merge
bmosaicml committed Feb 7, 2024
2 parents 53da3ea + 105f766 commit 16b8e32
Showing 44 changed files with 13,781 additions and 1,106 deletions.
24 changes: 15 additions & 9 deletions TUTORIAL.md
@@ -357,19 +357,25 @@ Currently we support [Learned Positional Embeddings](https://arxiv.org/pdf/1706.
| RoPE (Hugging<code>&nbsp;</code>Face Implementation) | <pre>model:<br> attn_config:<br> rope:&nbsp;True<br> rope_impl:&nbsp;hf</pre>| 62.3 | |

### Can I finetune using PEFT / LoRA?
- The LLM Foundry codebase does not directly have examples of PEFT or LORA workflows. However, our MPT model is a subclass of HuggingFace `PretrainedModel`, and https://github.com/mosaicml/llm-foundry/pull/346 added required features to enable HuggingFace’s [PEFT](https://huggingface.co/docs/peft/index) / [LORA](https://huggingface.co/docs/peft/conceptual_guides/lora) workflows for MPT. MPT models with LoRA modules can be trained either using LLM Foundry or Hugging Face's [accelerate](https://huggingface.co/docs/accelerate/index). Within LLM Foundry, run (`scripts/train/train.py`), adding `lora` arguments to the config `.yaml`, like so:
- LLM Foundry does support LoRA via an integration with the [PEFT](https://github.com/huggingface/peft) library. Within LLM Foundry, run (`scripts/train/train.py`), adding `peft_config` arguments to the `model` section of the config `.yaml`, like so:
<!--pytest.mark.skip-->
```yaml
lora:
  args:
    r: 16
    lora_alpha: 32
    lora_dropout: 0.05
    target_modules: ['Wqkv']
model:
  ...
  peft_config:
    r: 16
    peft_type: LORA
    task_type: CAUSAL_LM
    lora_alpha: 32
    lora_dropout: 0.05
    target_modules:
    - q_proj
    - k_proj
    target_modules:
    - 'Wqkv'
```
- In the current release, these features have Beta support.
- For efficiency, The MPT model concatenates the `Q`, `K`, and `V` matrices in each attention block into a single `Wqkv` matrix that is three times wider. Currently, LoRA supports a low-rank approximation to this `Wqkv` matrix.
- When evaluating with PEFT / LoRA seperated weight, just set `pretrained_lora_id_or_path` in `model`(Find an example [here](scripts/eval/yamls/hf_lora_eval.yml#L19)).
- When evaluating with PEFT / LoRA separated weight, just set `pretrained_lora_id_or_path` in `model`(Find an example [here](scripts/eval/yamls/hf_lora_eval.yml#L19)).
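- For a standalone reference outside of an LLM Foundry config, the same adapter setup can be expressed directly with the [PEFT](https://github.com/huggingface/peft) library. The sketch below is illustrative only and not part of this commit; the checkpoint name and the `Wqkv` target module are assumptions based on MPT's fused attention projection:
```python
# Illustrative sketch: building the LoRA adapter from the yaml above with PEFT
# directly. The checkpoint name and target module are assumptions, not
# prescribed by LLM Foundry.
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

base_model = AutoModelForCausalLM.from_pretrained(
    'mosaicml/mpt-7b',  # assumed example checkpoint
    trust_remote_code=True,
)

lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    task_type='CAUSAL_LM',
    target_modules=['Wqkv'],  # MPT fuses Q, K, and V into one Wqkv projection
)

model = get_peft_model(base_model, lora_config)
model.print_trainable_parameters()  # only the adapter weights are trainable
```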

### Can I quantize these models and/or run on CPU?
- The LLM Foundry codebase does not directly have examples of quantization or limited-resource inference. But you can check out [GGML](https://github.com/ggerganov/ggml) (same library that powers llama.cpp), which has built support for efficiently running MPT models on CPU! You _can_ load your model in 8-bit precision for inference using the [bitsandbytes library](https://github.com/TimDettmers/bitsandbytes) and Hugging Face's [accelerate](https://huggingface.co/docs/accelerate/index) via `model = AutoModelForCausalLM.from_pretrained(model_name, load_in_8bit=True, device_map="auto", trust_remote_code=True)`, although we have not extensively benchmarked the performance (see the Hugging Face [quantization documentation](https://huggingface.co/docs/transformers/main/main_classes/quantization) for more detail).
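- As a concrete illustration of that 8-bit path, here is a minimal sketch (not part of this commit); the checkpoint name is an assumed example, and the `bitsandbytes` and `accelerate` packages must be installed:
```python
# Illustrative sketch: 8-bit inference with bitsandbytes through the
# `load_in_8bit` flag. The checkpoint name is an assumed example.
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = 'mosaicml/mpt-7b'  # assumed example checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    load_in_8bit=True,   # requires bitsandbytes
    device_map='auto',   # requires accelerate
    trust_remote_code=True,
)

inputs = tokenizer('MosaicML is', return_tensors='pt').to(model.device)
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```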
2 changes: 1 addition & 1 deletion llmfoundry/__init__.py
@@ -95,4 +95,4 @@
'TiktokenTokenizerWrapper',
]

__version__ = '0.4.0'
__version__ = '0.5.0'
4 changes: 0 additions & 4 deletions llmfoundry/callbacks/__init__.py
@@ -5,9 +5,7 @@
from llmfoundry.callbacks.async_eval_callback import AsyncEval
from llmfoundry.callbacks.eval_gauntlet_callback import EvalGauntlet
from llmfoundry.callbacks.fdiff_callback import FDiffMetrics
from llmfoundry.callbacks.generate_callback import Generate
from llmfoundry.callbacks.hf_checkpointer import HuggingFaceCheckpointer
from llmfoundry.callbacks.model_gauntlet_callback import ModelGauntlet
from llmfoundry.callbacks.monolithic_ckpt_callback import \
MonolithicCheckpointSaver
from llmfoundry.callbacks.resumption_callbacks import (GlobalLRScaling,
@@ -21,13 +19,11 @@

__all__ = [
'FDiffMetrics',
'Generate',
'MonolithicCheckpointSaver',
'GlobalLRScaling',
'LayerFreezing',
'ScheduledGarbageCollector',
'EvalGauntlet',
'ModelGauntlet',
'HuggingFaceCheckpointer',
'AsyncEval',
]
3 changes: 1 addition & 2 deletions llmfoundry/callbacks/eval_gauntlet_callback.py
@@ -117,8 +117,7 @@ def __init__(self,
elif self.weighting == Weighting.SAMPLE_SZ:
weight = cumulative_samples
elif self.weighting == Weighting.LOG_SAMPLE_SZ:
weight = max(math.log(cumulative_samples, 2), 1)

weight = max(math.log2(cumulative_samples), 1)
assert weight is not None
benchmark['weighting'] = weight

30 changes: 0 additions & 30 deletions llmfoundry/callbacks/generate_callback.py

This file was deleted.

88 changes: 72 additions & 16 deletions llmfoundry/callbacks/hf_checkpointer.py
@@ -9,7 +9,7 @@
import re
import tempfile
from pathlib import Path
from typing import Optional, Sequence, Union
from typing import Any, Dict, Optional, Sequence, Union

import torch
from composer.core import Callback, Event, State, Time, TimeUnit
@@ -20,6 +20,7 @@
maybe_create_remote_uploader_downloader_from_uri,
parse_uri)
from composer.utils.misc import create_interval_scheduler
from mlflow.transformers import _fetch_model_card, _write_license_information
from transformers import PreTrainedModel, PreTrainedTokenizerBase

from llmfoundry.models.mpt import MPTConfig, MPTForCausalLM
@@ -32,17 +33,41 @@
_LICENSE_FILE_PATTERN = re.compile(r'license(\.[a-z]+|$)', re.IGNORECASE)


def _maybe_get_license_filename(local_dir: str) -> Optional[str]:
def _maybe_get_license_filename(
local_dir: str,
pretrained_model_name: Optional[str] = None) -> Optional[str]:
"""Returns the name of the license file if it exists in the local_dir.
Note: This is intended to be consistent with the code in MLflow.
https://github.com/mlflow/mlflow/blob/5d13d6ec620a02de9a5e31201bf1becdb9722ea5/mlflow/transformers/__init__.py#L1152
Since LLM Foundry supports local model files being used rather than fetching the files from the Hugging Face Hub,
MLflow's logic to fetch and write the license information on model save is not applicable; it will try to search for
a Hugging Face repo named after the local path. However, the user can provide the original pretrained model name,
in which case this function will use that to fetch the correct license information.
If the license file does not exist, returns None.
"""
try:
return next(file for file in os.listdir(local_dir)
if _LICENSE_FILE_PATTERN.search(file))
license_filename = next(file for file in os.listdir(local_dir)
if _LICENSE_FILE_PATTERN.search(file))

# If a pretrained model name is provided, replace the license file with the correct info from HF Hub.
if pretrained_model_name is not None:
log.info(
f'Overwriting license file {license_filename} with license info for model {pretrained_model_name} from Hugging Face Hub'
)
os.remove(os.path.join(local_dir, license_filename))
model_card = _fetch_model_card(pretrained_model_name)

local_dir_path = Path(local_dir).absolute()
_write_license_information(pretrained_model_name, model_card,
local_dir_path)
license_filename = next(file for file in os.listdir(local_dir)
if _LICENSE_FILE_PATTERN.search(file))

return license_filename

except StopIteration:
return None

@@ -203,14 +228,17 @@ def _save_checkpoint(self, state: State, logger: Logger):
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

if state.is_model_ddp:
composer_model = state.model.module
original_model: PreTrainedModel = state.model.module.model
state_dict_model = state.model.module.model
original_tokenizer = state.model.module.tokenizer
elif isinstance(state.model.model, FSDP):
composer_model = state.model
original_model: PreTrainedModel = state.model.model.module
state_dict_model = state.model.model
original_tokenizer = state.model.tokenizer
else:
composer_model = state.model
original_model: PreTrainedModel = state.model.model
state_dict_model = state.model.model
original_tokenizer = state.model.tokenizer
@@ -237,10 +265,23 @@ def _save_checkpoint(self, state: State, logger: Logger):
copied_config.init_device = 'cpu'

log.debug(f'Creating new model instance')
# First create the model instance on meta device to avoid the
# initialization cost.
with init_empty_weights():
new_model_instance = type(original_model)(copied_config)

if composer_model.using_peft:
# We don't use meta here because the state dict does not contain the full
# model, only the adapter weights.
active_adapter = original_model.active_adapter
base_model = original_model.get_base_model()
new_base_model_instance = type(base_model)(copied_config)

new_model_instance = type(original_model)(
new_base_model_instance,
original_model.peft_config[active_adapter])
new_model_instance.to(dtype=self.dtype)
else:
# First create the model instance on meta device to avoid the
# initialization cost.
with init_empty_weights():
new_model_instance = type(original_model)(copied_config)

# Then load the state dict in with "assign" so that the state dict
# is loaded properly even though the model is initially on meta device.
@@ -295,15 +336,30 @@ def _save_checkpoint(self, state: State, logger: Logger):
# TODO: Remove after mlflow fixes the bug that makes this necessary
import mlflow
mlflow.store._unity_catalog.registry.rest_store.get_feature_dependencies = lambda *args, **kwargs: ''
mlflow_logger.save_model(
flavor='transformers',
transformers_model=components,
path=local_save_path,
**self.mlflow_logging_config,
)

model_saving_kwargs: Dict[str, Any] = {
'path': local_save_path
}
if composer_model.using_peft:
model_saving_kwargs['flavor'] = 'peft'
model_saving_kwargs[
'save_pretrained_dir'] = temp_save_dir
model_saving_kwargs[
'metadata'] = self.mlflow_logging_config[
'metadata']
else:
model_saving_kwargs['flavor'] = 'transformers'
model_saving_kwargs[
'transformers_model'] = components
model_saving_kwargs.update(
self.mlflow_logging_config)

mlflow_logger.save_model(**model_saving_kwargs)

# Upload the license file generated by mlflow during the model saving.
license_filename = _maybe_get_license_filename(
local_save_path)
local_save_path,
self.mlflow_logging_config['metadata'].get(
'pretrained_model_name', None))
if license_filename is not None:
mlflow_logger._mlflow_client.log_artifact(
mlflow_logger._run_id,
21 changes: 0 additions & 21 deletions llmfoundry/callbacks/model_gauntlet_callback.py

This file was deleted.

45 changes: 32 additions & 13 deletions llmfoundry/data/finetuning/tasks.py
@@ -35,6 +35,7 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]:
import logging
import os
import warnings
from functools import partial
from pathlib import Path
from typing import (Any, Callable, Dict, List, Literal, Optional, Tuple, Union,
cast)
@@ -199,7 +200,7 @@ def _tokenize_prompt_response_formatted_example(
return tokenizer(text=prompt, text_target=response)


def _tokenize_formatted_example(
def tokenize_formatted_example(
example: Example,
tokenizer: PreTrainedTokenizerBase) -> TokenizedExample:
"""Tokenizes a formatted example using the provided tokenizer.
@@ -228,6 +229,33 @@ def _tokenize_formatted_example(
raise ValueError(f'Unknown conversation type {example_format=}')


def is_valid_ift_example(pad_token_id: int, max_seq_len: int,
example: Dict) -> bool:
"""Check if the example is a valid ift example.
This function does the following checks:
a. Length of input_ids should be less than max_seq_len
b. Both input_ids and labels should not be empty
c. Labels should have at least 1 non-padding token.
Args:
pad_token_id (int): The id of the padding token.
max_seq_len (int): Maximum sequence length.
example (Dict): The input example after tokenization, which has
``input_ids`` and ``labels`` fields.
Returns:
bool: Indicator of whether the input example is valid
"""
less_than_max_seq_len = len(example['input_ids']) < max_seq_len
non_empty_input = len(example['input_ids']) > 0
non_empty_labels = len(example['labels']) > 0
non_padding_response = any(
token_id != pad_token_id for token_id in example['labels'])
return (less_than_max_seq_len and non_empty_input and non_empty_labels and
non_padding_response)


class StreamingFinetuningDataset(StreamingDataset):
"""Finetuning dataset with flexible tokenization using StreamingDataset.
Expand Down Expand Up @@ -347,7 +375,7 @@ def __init__(self,
# How to process a sample
def __getitem__(self, idx: int) -> Dict[str, Any]:
sample = super().__getitem__(idx)
return _tokenize_formatted_example(sample, tokenizer=self.tokenizer)
return tokenize_formatted_example(sample, tokenizer=self.tokenizer)


class DatasetConstructor:
@@ -550,7 +578,7 @@ def build_from_hf(
def dataset_mapper(example: Dict):
if preprocessing_fn is not None:
example = preprocessing_fn(example)
return _tokenize_formatted_example(example, tokenizer)
return tokenize_formatted_example(example, tokenizer)

detected_cpu_count = os.cpu_count() or 1
detected_cpus_with_margin = detected_cpu_count - 8
@@ -567,17 +595,8 @@ def dataset_mapper(example: Dict):

pad_token_id = tokenizer.pad_token_id

def filter_long_or_empty_examples(example: Dict) -> bool:
less_than_max_seq_len = len(example['input_ids']) < max_seq_len
non_empty_input = len(example['input_ids']) > 0
non_empty_labels = len(example['labels']) > 0
non_padding_response = any(
token_id != pad_token_id for token_id in example['labels'])
return (less_than_max_seq_len and non_empty_input and
non_empty_labels and non_padding_response)

filtered_dataset = tokenized_dataset.filter(
filter_long_or_empty_examples,
partial(is_valid_ift_example, pad_token_id, max_seq_len),
num_proc=num_cpus_to_use,
desc='Filtering out long prompts',
)