Skip to content

Commit

Permalink
Merge branch 'main' into rotary_hf_imp
Browse files Browse the repository at this point in the history
  • Loading branch information
ShashankMosaicML authored Oct 26, 2023
2 parents f33ed5f + c60657b commit 07eafb7
Show file tree
Hide file tree
Showing 27 changed files with 728 additions and 393 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ python inference/convert_composer_to_hf.py \
# --hf_repo_for_upload user-org/repo-name

# Evaluate the model on a subset of tasks
python eval/eval.py \
composer eval/eval.py \
eval/yamls/hf_eval.yaml \
icl_tasks=eval/yamls/copa.yaml \
model_name_or_path=mpt-125m-hf
Expand Down
11 changes: 9 additions & 2 deletions llmfoundry/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@
import torch

try:
# Before importing any transformers models, we need to disable transformers flash attention if
# we are in an environment with flash attention version <2. Otherwise, transformers raises a hard
# error because one of its flash attention imports is not properly gated.
import transformers

from llmfoundry import optim, utils
from llmfoundry.data import (ConcatTokensDataset,
MixtureOfDenoisersCollator, NoConcatDataset,
Expand All @@ -14,8 +19,8 @@
ComposerHFT5)
from llmfoundry.models.layers.attention import (
MultiheadAttention, attn_bias_shape, build_alibi_bias, build_attn_bias,
flash_attn_fn, scaled_multihead_dot_product_attention,
triton_flash_attn_fn)
flash_attn_fn, is_flash_v1_installed,
scaled_multihead_dot_product_attention, triton_flash_attn_fn)
from llmfoundry.models.layers.blocks import MPTBlock
from llmfoundry.models.layers.ffn import (FFN_CLASS_REGISTRY, MPTMLP,
build_ffn)
Expand All @@ -24,6 +29,8 @@
MPTForCausalLM, MPTModel,
MPTPreTrainedModel)
from llmfoundry.tokenizers import TiktokenTokenizerWrapper
if is_flash_v1_installed():
transformers.utils.is_flash_attn_available = lambda: False

except ImportError as e:
try:
Expand Down
29 changes: 26 additions & 3 deletions llmfoundry/callbacks/hf_checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@
import contextlib
import copy
import logging
import math
import os
import tempfile
from pathlib import Path
from typing import Optional, Union

import torch
from composer.core import Callback, Event, State, Time
from composer.core import Callback, Event, State, Time, TimeUnit
from composer.core.state import fsdp_state_dict_type_context
from composer.loggers import Logger, MLFlowLogger
from composer.loggers.remote_uploader_downloader import RemoteUploaderDownloader
Expand Down Expand Up @@ -83,6 +84,13 @@ def __init__(

self.huggingface_folder_name_fstr = os.path.join(
'huggingface', huggingface_folder_name)

if isinstance(save_interval, str):
save_interval = Time.from_timestring(save_interval)
if isinstance(save_interval, int):
save_interval = Time(save_interval, TimeUnit.EPOCH)

self.save_interval = save_interval
self.check_interval = create_interval_scheduler(
save_interval, include_end_of_training=True)
self.upload_to_object_store = (self.backend != '')
Expand Down Expand Up @@ -128,6 +136,21 @@ def run_event(self, event: Event, state: State, logger: Logger) -> None:
mlflow.environment_variables.MLFLOW_HUGGINGFACE_MODEL_MAX_SHARD_SIZE.set(
'5GB')

# Report whether training has reached its final batch. The result gates the
# MLflow model registration step in _save_checkpoint.
# Returns True when the elapsed duration is at least 1.0, or, in the special
# case below, when the batch count is a multiple of the total batch count;
# returns False otherwise.
def _is_last_batch(self, state: State):
elapsed_duration = state.get_elapsed_duration()
# Common case: the fraction of training completed has reached 1.0.
if elapsed_duration is not None and elapsed_duration >= 1.0:
return True

assert state.max_duration is not None # for pyright
# If the save interval is specified as 1dur, and the max duration is in epoch units
# we need a special case to identify we are on the last batch and should write the mlflow checkpoint
# NOTE(review): presumably the elapsed duration has not yet reached 1.0 when the
# final batch checkpoint event fires in this configuration -- confirm against composer.
if self.save_interval.unit == TimeUnit.DURATION and self.save_interval.value == 1 and state.max_duration.unit == TimeUnit.EPOCH:
assert state.dataloader_len is not None # for pyright
# Total batches = epochs * batches per epoch, rounded up; we are on the last
# batch when the current batch count is an exact multiple of that total.
return int(state.timestamp.batch) % math.ceil(
state.max_duration.value * state.dataloader_len) == 0

return False

def _save_checkpoint(self, state: State, logger: Logger):
del logger # unused

Expand Down Expand Up @@ -224,8 +247,8 @@ def _save_checkpoint(self, state: State, logger: Logger):
overwrite=self.overwrite,
)

elapsed_duration = state.get_elapsed_duration()
if self.mlflow_registered_model_name is not None and elapsed_duration is not None and elapsed_duration >= 1.0:
if self.mlflow_registered_model_name and self._is_last_batch(
state):
components = {'model': new_model_instance}
if original_tokenizer is not None:
components['tokenizer'] = original_tokenizer
Expand Down
22 changes: 20 additions & 2 deletions llmfoundry/models/hf/hf_causal_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,7 @@

from llmfoundry.models.hf.hf_fsdp import hf_get_init_device
from llmfoundry.models.hf.model_wrapper import HuggingFaceModelWithZLoss
from llmfoundry.models.layers.llama_attention_monkeypatch import \
get_llama_attention_patch_fn
from llmfoundry.models.layers.attention import is_flash_v2_installed
from llmfoundry.models.utils import init_empty_weights

try:
Expand Down Expand Up @@ -95,12 +94,28 @@ def __init__(self, om_model_config: Union[DictConfig,
# load the model config
trust_remote_code = om_model_config.get('trust_remote_code', True)
use_auth_token = om_model_config.get('use_auth_token', False)
use_flash_attention_2 = om_model_config.get('use_flash_attention_2',
False)
if use_flash_attention_2 and not is_flash_v2_installed():
raise ValueError(
'use_flash_attention_2 is set to True, but flash-attention 2 is not installed. '
+ 'Please install flash_attn==2.3.2`.')

config = AutoConfig.from_pretrained(
om_model_config.pretrained_model_name_or_path,
trust_remote_code=trust_remote_code,
use_auth_token=use_auth_token,
)

# This is not how you are supposed to set this, but transformers currently only
# supports enabling flash attention 2 when using the from_pretrained API.
# We need to support it for both from_pretrained and from_config, so we have to
# set the private attribute here. This will just skip all of transformers'
# validation logic that it is ok to use flash attention 2, so we check
# whether it is installed above, and whether the chosen config supports it here.
# https://github.com/huggingface/transformers/issues/26878
config._flash_attn_2_enabled = use_flash_attention_2

# set config overrides
for k, v in om_model_config.get('config_overrides', {}).items():
if not hasattr(config, k):
Expand Down Expand Up @@ -200,6 +215,9 @@ def __init__(self, om_model_config: Union[DictConfig,
)
from transformers.models.llama.modeling_llama import \
LlamaAttention

from llmfoundry.models.layers.llama_attention_monkeypatch import \
get_llama_attention_patch_fn
LlamaAttention.forward = get_llama_attention_patch_fn(
attention_patch_type)
model.config.use_cache = False
Expand Down
34 changes: 17 additions & 17 deletions llmfoundry/tokenizers/tiktoken.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def __init__(self,
model_name: Optional[str] = None,
encoding_name: Optional[str] = None,
add_bos_token: bool = False,
add_eos_token: bool = False,
unk_token: Optional[str] = '<|endoftext|>',
eos_token: Optional[str] = '<|endoftext|>',
bos_token: Optional[str] = '<|endoftext|>',
Expand All @@ -36,6 +37,7 @@ def __init__(self,
encoding_name (Optional[str], optional): The name of the encoding to load from tiktoken. Defaults to None.
Either model_name or encoding_name must be set, but not both.
add_bos_token (bool, optional): Whether to add bos tokens. Defaults to False.
add_eos_token (bool, optional): Whether to add eos tokens. Defaults to False.
unk_token (Optional[str], optional): The unk token. Defaults to '<|endoftext|>'.
eos_token (Optional[str], optional): The eos token. Defaults to '<|endoftext|>'.
bos_token (Optional[str], optional): The bos token. Defaults to '<|endoftext|>'.
Expand Down Expand Up @@ -66,10 +68,12 @@ def __init__(self,
'You need to specify either model_name or encoding_name.')

self.add_bos_token = add_bos_token
self.add_eos_token = add_eos_token

super().__init__(model_name=model_name,
encoding_name=encoding_name,
add_bos_token=add_bos_token,
add_eos_token=add_eos_token,
unk_token=unk_token,
eos_token=eos_token,
bos_token=bos_token,
Expand Down Expand Up @@ -151,7 +155,7 @@ def convert_ids_to_tokens(
"""
if isinstance(ids, int):
if ids in self.added_tokens_decoder:
return self.added_tokens_decoder[ids]
return str(self.added_tokens_decoder[ids])

return self._convert_id_to_token(ids)

Expand All @@ -167,7 +171,7 @@ def convert_ids_to_tokens(
if index in self.added_tokens_decoder:
tokens.append(self.encoding.decode(current_stream))
current_stream = []
tokens.append(self.added_tokens_decoder[index])
tokens.append(str(self.added_tokens_decoder[index]))
else:
current_stream.append(index)

Expand All @@ -179,17 +183,15 @@ def build_inputs_with_special_tokens(
self,
token_ids_0: List[int],
token_ids_1: Optional[List[int]] = None) -> List[int]:
if self.add_bos_token:
bos_token_ids = [self.bos_token_id]
else:
bos_token_ids = []
bos_token_id = [self.bos_token_id] if self.add_bos_token else []
eos_token_id = [self.eos_token_id] if self.add_eos_token else []

output = bos_token_ids + token_ids_0
output = bos_token_id + token_ids_0 + eos_token_id

if token_ids_1 is None:
return output
if token_ids_1 is not None:
output = output + bos_token_id + token_ids_1 + eos_token_id

return output + bos_token_ids + token_ids_1
return output

def get_special_tokens_mask(
self,
Expand Down Expand Up @@ -221,15 +223,13 @@ def get_special_tokens_mask(
token_ids_1=token_ids_1,
already_has_special_tokens=True)

if not self.add_bos_token:
return super().get_special_tokens_mask(
token_ids_0=token_ids_0,
token_ids_1=token_ids_1,
already_has_special_tokens=False)
bos_token_id = [1] if self.add_bos_token else []
eos_token_id = [1] if self.add_eos_token else []

if token_ids_1 is None:
return [1] + ([0] * len(token_ids_0))
return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1))
return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id
return (bos_token_id + ([0] * len(token_ids_0)) + eos_token_id +
bos_token_id + ([0] * len(token_ids_1)) + eos_token_id)

def create_token_type_ids_from_sequences(
self,
Expand Down
62 changes: 55 additions & 7 deletions scripts/train/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,15 @@ This README walks through pretraining and finetuning a large language model usin
#### Table of Contents
1. [Part 1: LLM Pretraining](#llmpretraining)
1. [Installation](#installation)
2. [Dataset Preparation](#datasetpreparation)
3. [How to start single and multi-node pretraining](#howtostartpretraining)
2. [Part 2: LLM Finetuning](#llmfinetuning)
1. [Dataset Preparation](#datasetpreparation)
1. [How to start single and multi-node pretraining](#howtostartpretraining)
1. [Part 2: LLM Finetuning](#llmfinetuning)
1. [Using a dataset on the HuggingFace Hub](#hfdataset)
2. [Using a local dataset](#localdataset)
3. [Using a StreamingDataset (MDS) formatted dataset locally or in an object store](#mdsdataset)
3. [FAQ: How many GPUs do I need to train a LLM?](#howmandygpus)
4. [FAQ: Optimizing Performance](#optimizingperformance)
1. [Using a local dataset](#localdataset)
1. [Using a StreamingDataset (MDS) formatted dataset locally or in an object store](#mdsdataset)
1. [Using Flash Attention](#flashattention)
1. [FAQ: How many GPUs do I need to train a LLM?](#howmanygpus)
1. [FAQ: Optimizing Performance](#optimizingperformance)

# Part 1: LLM Pretraining <a name="llmpretraining"></a>

Expand Down Expand Up @@ -332,6 +333,53 @@ train_loader:
...
```
# Using Flash Attention <a name="flashattention"></a>

Flash Attention is an optimized implementation of the attention mechanism, first introduced by [Dao et al.](https://github.com/Dao-AILab/flash-attention). There are three versions of Flash Attention that can be used with LLM Foundry: Flash Attention V1, Flash Attention V2, and a Triton implementation of Flash Attention. To start, we recommend using one of our [provided Docker images](../../README.md#mosaicml-docker-images) corresponding to the Flash Attention version you would like to use. The Triton implementation can be used with either Flash Attention V1 or V2. How you enable Flash Attention depends on which model you are using.

For MPT, you can specify Flash Attention in your YAML like so:
```yaml
model:
name: mpt_causal_lm
...
attn_config:
# Will use either V1 or V2 depending on what is installed
# "triton" will use the Triton implementation
attn_impl: flash
...
```

If loading MPT from the HuggingFace Hub, you can specify Flash Attention in your YAML like so:
```yaml
model:
name: hf_causal_lm
pretrained_model_name_or_path: mosaicml/mpt-7b
...
config_overrides:
# Will use either V1 or V2 depending on what is installed
# "triton" will use the Triton implementation
attn_config:
attn_impl: flash
...
```

For any HuggingFace model that supports Flash Attention (e.g. Llama and Mistral), you can specify Flash Attention in your YAML like so:
```yaml
model:
name: hf_causal_lm
use_flash_attention_2: True # Defaults to False. Setting it to True requires Flash Attention V2 to be installed and a model that supports it
...
```
HuggingFace models currently only support Flash Attention V2.

For Llama specifically, we have another option if you would like to use the Triton implementation of Flash Attention. You can specify this in your YAML like so:
```yaml
model:
name: hf_causal_lm
pretrained_model_name_or_path: meta-llama/Llama-2-7b-hf
attention_patch_type: triton
...
```

# FAQ: How many GPUs do I need to train a LLM? <a name="howmanygpus"></a>
This is a complicated question in general, but if we assume that you are using FSDP with `FULL_SHARD`,
Expand Down
4 changes: 2 additions & 2 deletions scripts/train/benchmarking/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ Our microbatching engine enables microbatch sizes that do not divide Global Batch
## A100 80GB with 1600 Gbps node-node interconnect (RoCE)

| Model | SeqLen (T) | # GPUs | GPU | MFU | HFU | Model TFLOP | MicroBatchSize | GradAccum | GlobalBatchSize | Throughput (S/s) | Throughput (T/s) | Throughput (T/s/GPU) | GlobalBatchSize (T) | Precision | MP Mode | Sharding Strategy | Activation Checkpointing | Activation CPUOffload | NumParams |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 70b | 2048 | 64 | a100_80gb | 53.33 | 71.1 | 166 | 8 | 4 | 2048 | 12 | 26274 | 410 | 4194304 | bf16 | PURE | FULL_SHARD | True | False | 64862437376 |
| 70b | 2048 | 32 | a100_80gb | 48.56 | 64.75 | 151 | 2 | 16 | 1024 | 5 | 11962 | 373 | 2097152 | bf16 | PURE | FULL_SHARD | True | False | 64862437376 |
| 30b | 8192 | 8 | a100_80gb | 39.38 | 52.5 | 122 | 1 | 21 | 168 | 0 | 4594 | 574 | 1376256 | amp_bf16 | DEFAULT | FULL_SHARD | True | False | 30019254272 |
Expand Down Expand Up @@ -205,7 +205,7 @@ Our microbatching engine enables microbatch sizes that do not divde Global Batch
## A100 40GB with 1600 Gbps node-node interconnect (RoCE)

| Model | SeqLen (T) | # GPUs | GPU | MFU | HFU | Model TFLOP| MicroBatchSize | GradAccum | GlobalBatchSize | Throughput (S/s) | Throughput (T/s) | Throughput (T/s/GPU) | GlobalBatchSize (T) | Precision | MP Mode | Sharding Strategy | Activation Checkpointing | Activation CPUOffload | NumParams |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 70b | 2048 | 128 | a100_40gb | 48.91 | 65.21 | 152 | 4 | 1 | 512 | 23 | 48194 | 376 | 1048576 | bf16 | PURE | FULL_SHARD | True | False | 64862437376 |
| 70b | 2048 | 64 | a100_40gb | 35.87 | 47.82 | 111 | 2 | 1 | 128 | 8 | 17672 | 276 | 262144 | bf16 | PURE | FULL_SHARD | True | False | 64862437376 |
| 30b | 2048 | 128 | a100_40gb | 52.25 | 69.66 | 163 | 6 | 1 | 768 | 54 | 110803 | 865 | 1572864 | bf16 | PURE | FULL_SHARD | True | False | 29975214080 |
Expand Down
Loading

0 comments on commit 07eafb7

Please sign in to comment.