
Commit

Merge branch 'main' into cli99/act-checkpoint
dakinggg authored Nov 11, 2023
2 parents fcd3897 + e7223da commit fe02a3c
Showing 7 changed files with 123 additions and 83 deletions.
7 changes: 6 additions & 1 deletion .github/workflows/docker.yaml
@@ -29,7 +29,12 @@ jobs:
- name: '2.1.0_cu121_flash2'
base_image: mosaicml/pytorch:2.1.0_cu121-python3.10-ubuntu20.04
dep_groups: '[gpu-flash2]'

- name: '2.1.0_cu121_aws'
base_image: mosaicml/pytorch:2.1.0_cu121-python3.10-ubuntu20.04-aws
dep_groups: '[gpu]'
- name: '2.1.0_cu121_flash2_aws'
base_image: mosaicml/pytorch:2.1.0_cu121-python3.10-ubuntu20.04-aws
dep_groups: '[gpu-flash2]'
steps:
- name: Maximize Build Space on Worker
uses: easimon/maximize-build-space@v4
60 changes: 31 additions & 29 deletions README.md
Expand Up @@ -45,15 +45,15 @@ You'll find in this repo:
Mosaic Pretrained Transformers (MPT) are GPT-style models with some special features -- Flash Attention for efficiency, ALiBi for context length extrapolation, and stability improvements to mitigate loss spikes. As part of MosaicML's Foundation series, we have open-sourced several MPT models:


| Model | Context Length | Download | Demo | Commercial use? |
|--------------------|----------------|----------------------------------------------------|------------------------------------------------------------------|-----------------|
| MPT-30B | 8192 | https://huggingface.co/mosaicml/mpt-30b | | Yes |
| MPT-30B-Instruct | 8192 | https://huggingface.co/mosaicml/mpt-30b-instruct | | Yes |
| MPT-30B-Chat | 8192 | https://huggingface.co/mosaicml/mpt-30b-chat | [Demo](https://huggingface.co/spaces/mosaicml/mpt-30b-chat) | No |
| MPT-7B | 2048 | https://huggingface.co/mosaicml/mpt-7b | | Yes |
| MPT-7B-Instruct | 2048 | https://huggingface.co/mosaicml/mpt-7b-instruct | | Yes |
| MPT-7B-Chat | 2048 | https://huggingface.co/mosaicml/mpt-7b-chat | [Demo](https://huggingface.co/spaces/mosaicml/mpt-7b-chat) | No |
| MPT-7B-StoryWriter | 65536 | https://huggingface.co/mosaicml/mpt-7b-storywriter | | Yes |
| Model | Context Length | Download | Demo | Commercial use? |
| ------------------ | -------------- | -------------------------------------------------- | ----------------------------------------------------------- | --------------- |
| MPT-30B | 8192 | https://huggingface.co/mosaicml/mpt-30b | | Yes |
| MPT-30B-Instruct | 8192 | https://huggingface.co/mosaicml/mpt-30b-instruct | | Yes |
| MPT-30B-Chat | 8192 | https://huggingface.co/mosaicml/mpt-30b-chat | [Demo](https://huggingface.co/spaces/mosaicml/mpt-30b-chat) | No |
| MPT-7B | 2048 | https://huggingface.co/mosaicml/mpt-7b | | Yes |
| MPT-7B-Instruct | 2048 | https://huggingface.co/mosaicml/mpt-7b-instruct | | Yes |
| MPT-7B-Chat | 2048 | https://huggingface.co/mosaicml/mpt-7b-chat | [Demo](https://huggingface.co/spaces/mosaicml/mpt-7b-chat) | No |
| MPT-7B-StoryWriter | 65536 | https://huggingface.co/mosaicml/mpt-7b-storywriter | | Yes |

To try out these models locally, [follow the instructions](https://github.com/mosaicml/llm-foundry/tree/main/scripts/inference#interactive-generation-with-modelgenerate) in `scripts/inference/README.md` to prompt HF models using our [hf_generate.py](https://github.com/mosaicml/llm-foundry/blob/main/scripts/inference/hf_generate.py) or [hf_chat.py](https://github.com/mosaicml/llm-foundry/blob/main/scripts/inference/hf_chat.py) scripts.
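
For a quick local smoke test, a minimal generation call might look like the following sketch (this uses the generic `transformers` pipeline API rather than the repo's own scripts; the model name and generation arguments are illustrative):

```python
# Illustrative sketch, not one of the llm-foundry scripts: generate text from an
# MPT checkpoint listed in the table above using the Hugging Face pipeline API.
import transformers

pipe = transformers.pipeline(
    'text-generation',
    model='mosaicml/mpt-7b',
    trust_remote_code=True,  # MPT models ship custom modeling code on the Hub
    device_map='auto',       # requires `accelerate`; remove to run on CPU
)
out = pipe('Here is a recipe for vegan banana bread:\n', max_new_tokens=100)
print(out[0]['generated_text'])
```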

@@ -89,17 +89,17 @@ This codebase has been tested with PyTorch 1.13.1 and PyTorch 2.0.1 on systems w
This codebase may also work on systems with other devices, such as consumer NVIDIA cards and AMD cards, but we are not actively testing these systems.
If you have success/failure using LLM Foundry on other systems, please let us know in a Github issue and we will update the support matrix!

| Device | Torch Version | Cuda Version | Status |
|---------------------------|------------------|--------------|-------------------------------|
| A100-40GB/80GB | 1.13.1 | 11.7 | :white_check_mark: Supported |
| A100-40GB/80GB | 2.0.1 | 11.7, 11.8 | :white_check_mark: Supported |
| A100-40GB/80GB | 2.1.0 | 11.8, 12.1 | :white_check_mark: Supported |
| H100-80GB | 1.13.1 | 11.7 | :x: Not Supported |
| H100-80GB | 2.0.1 | 11.8 | :white_check_mark: Supported |
| H100-80GB | 2.1.0 | 12.1 | :white_check_mark: Supported |
| A10-24GB | 1.13.1 | 11.7 | :construction: In Progress |
| A10-24GB | 2.0.1 | 11.7, 11.8 | :construction: In Progress |
| MI250 | 2.0.1 | ROCm 5.4 | :construction: In Progress |
| Device | Torch Version | Cuda Version | Status |
| -------------- | ------------- | ------------ | ---------------------------- |
| A100-40GB/80GB | 1.13.1 | 11.7 | :white_check_mark: Supported |
| A100-40GB/80GB | 2.0.1 | 11.7, 11.8 | :white_check_mark: Supported |
| A100-40GB/80GB | 2.1.0 | 11.8, 12.1 | :white_check_mark: Supported |
| H100-80GB | 1.13.1 | 11.7 | :x: Not Supported |
| H100-80GB | 2.0.1 | 11.8 | :white_check_mark: Supported |
| H100-80GB | 2.1.0 | 12.1 | :white_check_mark: Supported |
| A10-24GB | 1.13.1 | 11.7 | :construction: In Progress |
| A10-24GB | 2.0.1 | 11.7, 11.8 | :construction: In Progress |
| MI250 | 2.0.1 | ROCm 5.4 | :construction: In Progress |

## MosaicML Docker Images
We highly recommend using our prebuilt Docker images. You can find them here: https://hub.docker.com/orgs/mosaicml/repositories.
@@ -111,15 +111,17 @@ You can select a specific commit hash such as `mosaicml/llm-foundry:1.13.1_cu117

**Please Note:** The `mosaicml/llm-foundry` images do not come with the `llm-foundry` package preinstalled, just the dependencies. You will still need to `pip install llm-foundry` either from PyPi or from source.

| Docker Image | Torch Version | Cuda Version | LLM Foundry dependencies installed? |
|-------------------------------------------------------------|----------------|--------------|-------------------------------------|
| `mosaicml/pytorch:1.13.1_cu117-python3.10-ubuntu20.04` | 1.13.1 | 11.7 | No |
| `mosaicml/pytorch:2.0.1_cu118-python3.10-ubuntu20.04` | 2.0.1 | 11.8 | No |
| `mosaicml/pytorch:2.1.0_cu121-python3.10-ubuntu20.04` | 2.1.0 | 12.1 | No |
| `mosaicml/llm-foundry:1.13.1_cu117-latest` | 1.13.1 | 11.7 | Yes |
| `mosaicml/llm-foundry:2.0.1_cu118-latest` | 2.0.1 | 11.8 | Yes |
| `mosaicml/llm-foundry:2.1.0_cu121-latest` | 2.1.0 | 12.1 | Yes (flash attention v1) |
| `mosaicml/llm-foundry:2.1.0_cu121_flash2-latest` | 2.1.0 | 12.1 | Yes (flash attention v2) |
| Docker Image | Torch Version | Cuda Version | LLM Foundry dependencies installed? |
| ------------------------------------------------------ | ------------- | ----------------- | ----------------------------------- |
| `mosaicml/pytorch:1.13.1_cu117-python3.10-ubuntu20.04` | 1.13.1 | 11.7 (Infiniband) | No |
| `mosaicml/pytorch:2.0.1_cu118-python3.10-ubuntu20.04` | 2.0.1 | 11.8 (Infiniband) | No |
| `mosaicml/pytorch:2.1.0_cu121-python3.10-ubuntu20.04` | 2.1.0 | 12.1 (Infiniband) | No |
| `mosaicml/llm-foundry:1.13.1_cu117-latest` | 1.13.1 | 11.7 (Infiniband) | Yes |
| `mosaicml/llm-foundry:2.0.1_cu118-latest` | 2.0.1 | 11.8 (Infiniband) | Yes |
| `mosaicml/llm-foundry:2.1.0_cu121-latest` | 2.1.0 | 12.1 (Infiniband) | Yes (flash attention v1) |
| `mosaicml/llm-foundry:2.1.0_cu121_flash2-latest` | 2.1.0 | 12.1 (Infiniband) | Yes (flash attention v2) |
| `mosaicml/llm-foundry:2.1.0_cu121_aws-latest` | 2.1.0 | 12.1 (EFA) | Yes (flash attention v1) |
| `mosaicml/llm-foundry:2.1.0_cu121_flash2_aws-latest` | 2.1.0 | 12.1 (EFA) | Yes (flash attention v2) |


# Installation
31 changes: 13 additions & 18 deletions llmfoundry/callbacks/hf_checkpointer.py
@@ -14,9 +14,10 @@
from composer.core import Callback, Event, State, Time, TimeUnit
from composer.core.state import fsdp_state_dict_type_context
from composer.loggers import Logger, MLFlowLogger
from composer.loggers.remote_uploader_downloader import RemoteUploaderDownloader
from composer.models import HuggingFaceModel
from composer.utils import dist, format_name_with_dist_and_time, parse_uri
from composer.utils import (dist, format_name_with_dist_and_time,
maybe_create_remote_uploader_downloader_from_uri,
parse_uri)
from composer.utils.misc import create_interval_scheduler
from transformers import PreTrainedModel, PreTrainedTokenizerBase

@@ -57,8 +58,7 @@ def __init__(
mlflow_registered_model_name: Optional[str] = None,
mlflow_logging_config: Optional[dict] = None,
):
self.backend, self.bucket_name, self.save_dir_format_str = parse_uri(
save_folder)
_, _, self.save_dir_format_str = parse_uri(save_folder)
self.overwrite = overwrite
self.precision = precision
self.dtype = {
@@ -93,13 +93,11 @@ def __init__(
self.save_interval = save_interval
self.check_interval = create_interval_scheduler(
save_interval, include_end_of_training=True)
self.upload_to_object_store = (self.backend != '')
if self.upload_to_object_store:
self.remote_ud = RemoteUploaderDownloader(
bucket_uri=f'{self.backend}://{self.bucket_name}',
num_concurrent_uploads=4)
else:
self.remote_ud = None

self.remote_ud = maybe_create_remote_uploader_downloader_from_uri(
save_folder, loggers=[])
if self.remote_ud is not None:
self.remote_ud._num_concurrent_uploads = 4

self.last_checkpoint_batch: Optional[Time] = None
self.mlflow_loggers = []
@@ -115,7 +113,7 @@ def run_event(self, event: Event, state: State, logger: Logger) -> None:
raise ValueError(
f'`HuggingFaceCheckpointer` is only compatible with `HuggingFaceModel`s. '
+ f'Got {type(state.model)} instead.')
if self.upload_to_object_store and self.remote_ud is not None:
if self.remote_ud is not None:
self.remote_ud.init(state, logger)
state.callbacks.append(self.remote_ud)

@@ -169,7 +167,7 @@ def _save_checkpoint(self, state: State, logger: Logger):
self.huggingface_folder_name_fstr), state.run_name,
state.timestamp)
dir_context_mgr = tempfile.TemporaryDirectory(
) if self.upload_to_object_store else contextlib.nullcontext(
) if self.remote_ud is not None else contextlib.nullcontext(
enter_result=save_dir)

with dir_context_mgr as temp_save_dir:
@@ -233,11 +231,8 @@ def _save_checkpoint(self, state: State, logger: Logger):
log.debug('Editing MPT files for HuggingFace compatibility')
edit_files_for_hf_compatibility(temp_save_dir)

if self.upload_to_object_store:
assert self.remote_ud is not None
log.info(
f'Uploading HuggingFace formatted checkpoint to {self.backend}://{self.bucket_name}/{save_dir}'
)
if self.remote_ud is not None:
log.info(f'Uploading HuggingFace formatted checkpoint')
for filename in os.listdir(temp_save_dir):
self.remote_ud.upload_file(
state=state,
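
As a rough usage sketch (constructor arguments are inferred from the diff above and may not match the full signature), the callback now decides whether to upload based solely on the `save_folder` URI:

```python
# Hypothetical configuration sketch: a local path yields no remote uploader,
# while a remote URI (e.g. s3://) makes the callback create one internally via
# maybe_create_remote_uploader_downloader_from_uri.
from llmfoundry.callbacks.hf_checkpointer import HuggingFaceCheckpointer

local_ckpt = HuggingFaceCheckpointer(
    save_folder='./hf_checkpoints',               # local: remote_ud is None
    save_interval='1000ba',
    precision='bfloat16',
)

remote_ckpt = HuggingFaceCheckpointer(
    save_folder='s3://my-bucket/hf_checkpoints',  # remote: uploads after each save
    save_interval='1ep',
    precision='bfloat16',
)
```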
42 changes: 19 additions & 23 deletions llmfoundry/data/finetuning/tasks.py
@@ -362,36 +362,32 @@ def dataset_mapper(example: Dict):
num_proc=num_cpus_to_use,
desc='Tokenizing dataset',
)
prompt_length_filtered_dataset = tokenized_dataset.filter(
lambda example: len(example['input_ids']) < max_seq_len,

pad_token_id = tokenizer.pad_token_id

def filter_long_or_empty_examples(example: Dict) -> bool:
less_than_max_seq_len = len(example['input_ids']) < max_seq_len
non_empty_input = len(example['input_ids']) > 0
non_empty_labels = len(example['labels']) > 0
non_padding_response = any(
token_id != pad_token_id for token_id in example['labels'])
return (less_than_max_seq_len and non_empty_input and
non_empty_labels and non_padding_response)

filtered_dataset = tokenized_dataset.filter(
filter_long_or_empty_examples,
num_proc=num_cpus_to_use,
desc='Filtering out long prompts',
)

examples_removed = len(tokenized_dataset) - len(
prompt_length_filtered_dataset)
examples_removed = len(tokenized_dataset) - len(filtered_dataset)
if examples_removed > 0:
warnings.warn(
f'Dropped {examples_removed} examples where the prompt was longer than {max_seq_len}.'
f'Dropped {examples_removed} examples where the prompt was longer than {max_seq_len}, '
+
'the prompt or response was empty, or the response was all padding tokens.'
)

pad_token_id = tokenizer.pad_token_id
empty_examples_dropped_dataset = prompt_length_filtered_dataset.filter(
lambda example: len(example['input_ids']) > 0 and len(example[
'labels']) > 0 and any(token_id != pad_token_id
for token_id in example['labels']),
num_proc=num_cpus_to_use,
desc='Filtering out empty examples')

log.debug('Done tokenizing and filtering examples.')

empty_examples_removed = len(prompt_length_filtered_dataset) - len(
empty_examples_dropped_dataset)
if empty_examples_removed > 0:
warnings.warn(
f'Dropped {empty_examples_removed} examples where the prompt or response was empty, '
+ 'or the response was only padding tokens.')

# Now local rank 0 indicates to the other ranks that it is done
if dist.get_local_rank() == 0:
log.debug('Local rank 0 finished data prep')
@@ -406,7 +402,7 @@ def dataset_mapper(example: Dict):
os.remove(signal_file_path)

log.debug('All ranks finished data prep')
return empty_examples_dropped_dataset
return filtered_dataset

def build_from_streaming(self, *args: Any,
**kwargs: Any) -> StreamingFinetuningDataset:
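
To make the combined filter concrete, here is a small self-contained sketch (toy data, assumed `pad_token_id` of 0) showing which examples the new predicate keeps:

```python
# Toy illustration of the consolidated filter: one pass drops over-long prompts,
# empty inputs/labels, and responses that are entirely padding tokens.
from datasets import Dataset

max_seq_len = 4
pad_token_id = 0

ds = Dataset.from_dict({
    'input_ids': [[5, 6], [7, 8, 9, 10, 11], [], [5, 6]],
    'labels':    [[5, 6], [7, 8],            [1], [0, 0]],
})

def filter_long_or_empty_examples(example):
    less_than_max_seq_len = len(example['input_ids']) < max_seq_len
    non_empty_input = len(example['input_ids']) > 0
    non_empty_labels = len(example['labels']) > 0
    non_padding_response = any(t != pad_token_id for t in example['labels'])
    return (less_than_max_seq_len and non_empty_input and
            non_empty_labels and non_padding_response)

filtered = ds.filter(filter_long_or_empty_examples)
print(f'kept {len(filtered)} of {len(ds)} examples')  # kept 1 of 4 examples
```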
19 changes: 19 additions & 0 deletions llmfoundry/utils/builders.py
@@ -188,6 +188,14 @@ def build_tokenizer(
os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = '1'
os.environ['TOKENIZERS_PARALLELISM'] = 'false'

signal_file_path = f'.node_{dist.get_node_rank()}_local_rank0_completed_tokenizer_setup'

if dist.is_available() and dist.is_initialized(
) and dist.get_world_size() > 1:
# Make sure the tokenizer files are downloaded and cached first by local rank 0
with dist.local_rank_zero_download_and_wait(signal_file_path):
pass

if tokenizer_name.startswith('tiktoken'):
tokenizer = TiktokenTokenizerWrapper(**tokenizer_kwargs)
else:
@@ -202,6 +210,17 @@ def build_tokenizer(
int(1e30),
)

if dist.is_available() and dist.is_initialized(
) and dist.get_world_size() > 1:
if dist.get_local_rank() == 0:
with open(signal_file_path, 'wb') as f:
f.write(b'local_rank0_completed_tokenizer_setup')

dist.barrier()

if dist.get_local_rank() == 0:
os.remove(signal_file_path)

return tokenizer


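
A minimal calling sketch (argument values are illustrative): under multi-GPU training, local rank 0 now downloads and caches the tokenizer first, and the remaining ranks wait on the signal file and barrier added above, with no change needed on the caller's side:

```python
# Hypothetical usage: the caller does not change; the rank-0-first download and
# the barrier are handled inside build_tokenizer.
from llmfoundry.utils.builders import build_tokenizer

tokenizer = build_tokenizer(
    tokenizer_name='mosaicml/mpt-7b',
    tokenizer_kwargs={'model_max_length': 2048},
)
print(tokenizer.model_max_length)
```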
27 changes: 17 additions & 10 deletions llmfoundry/utils/model_download_utils.py
@@ -6,6 +6,7 @@
import logging
import os
import time
import warnings
from http import HTTPStatus
from typing import Optional
from urllib.parse import urljoin
@@ -14,6 +15,7 @@
import requests
import tenacity
from bs4 import BeautifulSoup
from requests.packages.urllib3.exceptions import InsecureRequestWarning
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME
from transformers.utils import WEIGHTS_INDEX_NAME as PYTORCH_WEIGHTS_INDEX_NAME
from transformers.utils import WEIGHTS_NAME as PYTORCH_WEIGHTS_NAME
@@ -212,16 +214,21 @@ def download_from_cache_server(

download_start = time.time()

# Only downloads the blobs in order to avoid downloading model files twice due to the
# symlinks in the Hugging Face cache structure:
_recursive_download(
session,
cache_base_url,
# Trailing slash to indicate directory
f'{formatted_model_name}/blobs/',
save_dir,
ignore_cert=ignore_cert,
)
# Temporarily suppress noisy SSL certificate verification warnings if ignore_cert is set to True
with warnings.catch_warnings():
if ignore_cert:
warnings.simplefilter('ignore', category=InsecureRequestWarning)

# Only downloads the blobs in order to avoid downloading model files twice due to the
# symlinks in the Hugging Face cache structure:
_recursive_download(
session,
cache_base_url,
# Trailing slash to indicate directory
f'{formatted_model_name}/blobs/',
save_dir,
ignore_cert=ignore_cert,
)
download_duration = time.time() - download_start
log.info(
f'Downloaded model {model_name} from cache server in {download_duration} seconds'
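
The same suppression pattern in isolation looks roughly like this (standalone sketch; `fetch` and its arguments are hypothetical and not part of the module):

```python
# Standalone sketch of the warning-suppression pattern: InsecureRequestWarning is
# silenced only while ignore_cert is True, and only within the catch_warnings block.
import warnings

import requests
from urllib3.exceptions import InsecureRequestWarning

def fetch(url: str, ignore_cert: bool = False) -> bytes:
    with warnings.catch_warnings():
        if ignore_cert:
            warnings.simplefilter('ignore', category=InsecureRequestWarning)
        resp = requests.get(url, verify=not ignore_cert)
        resp.raise_for_status()
        return resp.content
```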
