
Commit 62295e8

Merge branch 'mosaicml:main' into notie_embd
vchiley authored Nov 10, 2023
2 parents 1160b04 + e7223da commit 62295e8
Showing 6 changed files with 102 additions and 80 deletions.
7 changes: 6 additions & 1 deletion .github/workflows/docker.yaml
@@ -29,7 +29,12 @@ jobs:
- name: '2.1.0_cu121_flash2'
base_image: mosaicml/pytorch:2.1.0_cu121-python3.10-ubuntu20.04
dep_groups: '[gpu-flash2]'

- name: '2.1.0_cu121_aws'
base_image: mosaicml/pytorch:2.1.0_cu121-python3.10-ubuntu20.04-aws
dep_groups: '[gpu]'
- name: '2.1.0_cu121_flash2_aws'
base_image: mosaicml/pytorch:2.1.0_cu121-python3.10-ubuntu20.04-aws
dep_groups: '[gpu-flash2]'
steps:
- name: Maximize Build Space on Worker
uses: easimon/maximize-build-space@v4
60 changes: 31 additions & 29 deletions README.md
@@ -45,15 +45,15 @@ You'll find in this repo:
Mosaic Pretrained Transformers (MPT) are GPT-style models with some special features -- Flash Attention for efficiency, ALiBi for context length extrapolation, and stability improvements to mitigate loss spikes. As part of MosaicML's Foundation series, we have open-sourced several MPT models:


| Model | Context Length | Download | Demo | Commercial use? |
|--------------------|----------------|----------------------------------------------------|------------------------------------------------------------------|-----------------|
| MPT-30B | 8192 | https://huggingface.co/mosaicml/mpt-30b | | Yes |
| MPT-30B-Instruct | 8192 | https://huggingface.co/mosaicml/mpt-30b-instruct | | Yes |
| MPT-30B-Chat | 8192 | https://huggingface.co/mosaicml/mpt-30b-chat | [Demo](https://huggingface.co/spaces/mosaicml/mpt-30b-chat) | No |
| MPT-7B | 2048 | https://huggingface.co/mosaicml/mpt-7b | | Yes |
| MPT-7B-Instruct | 2048 | https://huggingface.co/mosaicml/mpt-7b-instruct | | Yes |
| MPT-7B-Chat | 2048 | https://huggingface.co/mosaicml/mpt-7b-chat | [Demo](https://huggingface.co/spaces/mosaicml/mpt-7b-chat) | No |
| MPT-7B-StoryWriter | 65536 | https://huggingface.co/mosaicml/mpt-7b-storywriter | | Yes |
| Model | Context Length | Download | Demo | Commercial use? |
| ------------------ | -------------- | -------------------------------------------------- | ----------------------------------------------------------- | --------------- |
| MPT-30B | 8192 | https://huggingface.co/mosaicml/mpt-30b | | Yes |
| MPT-30B-Instruct | 8192 | https://huggingface.co/mosaicml/mpt-30b-instruct | | Yes |
| MPT-30B-Chat | 8192 | https://huggingface.co/mosaicml/mpt-30b-chat | [Demo](https://huggingface.co/spaces/mosaicml/mpt-30b-chat) | No |
| MPT-7B | 2048 | https://huggingface.co/mosaicml/mpt-7b | | Yes |
| MPT-7B-Instruct | 2048 | https://huggingface.co/mosaicml/mpt-7b-instruct | | Yes |
| MPT-7B-Chat | 2048 | https://huggingface.co/mosaicml/mpt-7b-chat | [Demo](https://huggingface.co/spaces/mosaicml/mpt-7b-chat) | No |
| MPT-7B-StoryWriter | 65536 | https://huggingface.co/mosaicml/mpt-7b-storywriter | | Yes |

To try out these models locally, [follow the instructions](https://github.com/mosaicml/llm-foundry/tree/main/scripts/inference#interactive-generation-with-modelgenerate) in `scripts/inference/README.md` to prompt HF models using our [hf_generate.py](https://github.com/mosaicml/llm-foundry/blob/main/scripts/inference/hf_generate.py) or [hf_chat.py](https://github.com/mosaicml/llm-foundry/blob/main/scripts/inference/hf_chat.py) scripts.
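
For a quick sanity check outside those scripts, a minimal sketch using plain `transformers` (not the repo's `hf_generate.py`; the model name and prompt below are just placeholders) looks like this:

```python
# Minimal sketch: prompt an MPT checkpoint directly with Hugging Face transformers.
# Assumes `torch` and `transformers` are installed; any model from the table above works.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

name = 'mosaicml/mpt-7b'  # placeholder: swap in any MPT checkpoint
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModelForCausalLM.from_pretrained(
    name,
    torch_dtype=torch.bfloat16,  # use float32 if running on CPU
    trust_remote_code=True,      # MPT ships custom modeling code on the Hub
)
model.eval()

inputs = tokenizer('Here is a recipe for vegan banana bread:\n', return_tensors='pt')
with torch.no_grad():
    out = model.generate(**inputs, max_new_tokens=100, do_sample=True, top_p=0.95)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```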

@@ -89,17 +89,17 @@ This codebase has been tested with PyTorch 1.13.1 and PyTorch 2.0.1 on systems w
This codebase may also work on systems with other devices, such as consumer NVIDIA cards and AMD cards, but we are not actively testing these systems.
If you have success or failure using LLM Foundry on other systems, please let us know in a GitHub issue and we will update the support matrix!

| Device | Torch Version | Cuda Version | Status |
|---------------------------|------------------|--------------|-------------------------------|
| A100-40GB/80GB | 1.13.1 | 11.7 | :white_check_mark: Supported |
| A100-40GB/80GB | 2.0.1 | 11.7, 11.8 | :white_check_mark: Supported |
| A100-40GB/80GB | 2.1.0 | 11.8, 12.1 | :white_check_mark: Supported |
| H100-80GB | 1.13.1 | 11.7 | :x: Not Supported |
| H100-80GB | 2.0.1 | 11.8 | :white_check_mark: Supported |
| H100-80GB | 2.1.0 | 12.1 | :white_check_mark: Supported |
| A10-24GB | 1.13.1 | 11.7 | :construction: In Progress |
| A10-24GB | 2.0.1 | 11.7, 11.8 | :construction: In Progress |
| MI250 | 2.0.1 | ROCm 5.4 | :construction: In Progress |
| Device | Torch Version | Cuda Version | Status |
| -------------- | ------------- | ------------ | ---------------------------- |
| A100-40GB/80GB | 1.13.1 | 11.7 | :white_check_mark: Supported |
| A100-40GB/80GB | 2.0.1 | 11.7, 11.8 | :white_check_mark: Supported |
| A100-40GB/80GB | 2.1.0 | 11.8, 12.1 | :white_check_mark: Supported |
| H100-80GB | 1.13.1 | 11.7 | :x: Not Supported |
| H100-80GB | 2.0.1 | 11.8 | :white_check_mark: Supported |
| H100-80GB | 2.1.0 | 12.1 | :white_check_mark: Supported |
| A10-24GB | 1.13.1 | 11.7 | :construction: In Progress |
| A10-24GB | 2.0.1 | 11.7, 11.8 | :construction: In Progress |
| MI250 | 2.0.1 | ROCm 5.4 | :construction: In Progress |
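
To see where your own environment falls in this matrix, a quick check (a sketch that assumes PyTorch is installed) is:

```python
# Sketch: print the Torch/CUDA/GPU combination to compare against the table above.
import torch

print('torch:', torch.__version__)
print('cuda :', torch.version.cuda)  # None for CPU-only builds
print('gpu  :', torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'none detected')
```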

## MosaicML Docker Images
We highly recommend using our prebuilt Docker images. You can find them here: https://hub.docker.com/orgs/mosaicml/repositories.
@@ -111,15 +111,17 @@ You can select a specific commit hash such as `mosaicml/llm-foundry:1.13.1_cu117

**Please Note:** The `mosaicml/llm-foundry` images do not come with the `llm-foundry` package preinstalled, just the dependencies. You will still need to `pip install llm-foundry` either from PyPI or from source.

| Docker Image | Torch Version | Cuda Version | LLM Foundry dependencies installed? |
|-------------------------------------------------------------|----------------|--------------|-------------------------------------|
| `mosaicml/pytorch:1.13.1_cu117-python3.10-ubuntu20.04` | 1.13.1 | 11.7 | No |
| `mosaicml/pytorch:2.0.1_cu118-python3.10-ubuntu20.04` | 2.0.1 | 11.8 | No |
| `mosaicml/pytorch:2.1.0_cu121-python3.10-ubuntu20.04` | 2.1.0 | 12.1 | No |
| `mosaicml/llm-foundry:1.13.1_cu117-latest` | 1.13.1 | 11.7 | Yes |
| `mosaicml/llm-foundry:2.0.1_cu118-latest` | 2.0.1 | 11.8 | Yes |
| `mosaicml/llm-foundry:2.1.0_cu121-latest` | 2.1.0 | 12.1 | Yes (flash attention v1) |
| `mosaicml/llm-foundry:2.1.0_cu121_flash2-latest` | 2.1.0 | 12.1 | Yes (flash attention v2) |
| Docker Image | Torch Version | Cuda Version | LLM Foundry dependencies installed? |
| ------------------------------------------------------ | ------------- | ----------------- | ----------------------------------- |
| `mosaicml/pytorch:1.13.1_cu117-python3.10-ubuntu20.04` | 1.13.1 | 11.7 (Infiniband) | No |
| `mosaicml/pytorch:2.0.1_cu118-python3.10-ubuntu20.04` | 2.0.1 | 11.8 (Infiniband) | No |
| `mosaicml/pytorch:2.1.0_cu121-python3.10-ubuntu20.04` | 2.1.0 | 12.1 (Infiniband) | No |
| `mosaicml/llm-foundry:1.13.1_cu117-latest` | 1.13.1 | 11.7 (Infiniband) | Yes |
| `mosaicml/llm-foundry:2.0.1_cu118-latest` | 2.0.1 | 11.8 (Infiniband) | Yes |
| `mosaicml/llm-foundry:2.1.0_cu121-latest` | 2.1.0 | 12.1 (Infiniband) | Yes (flash attention v1) |
| `mosaicml/llm-foundry:2.1.0_cu121_flash2-latest` | 2.1.0 | 12.1 (Infiniband) | Yes (flash attention v2) |
| `mosaicml/llm-foundry:2.1.0_cu121_aws-latest` | 2.1.0 | 12.1 (EFA) | Yes (flash attention v1) |
| `mosaicml/llm-foundry:2.1.0_cu121_flash2_aws-latest` | 2.1.0 | 12.1 (EFA) | Yes (flash attention v2) |


# Installation
46 changes: 17 additions & 29 deletions llmfoundry/data/finetuning/tasks.py
@@ -363,43 +363,31 @@ def dataset_mapper(example: Dict):
desc='Tokenizing dataset',
)

def filter_long_prompts(example: Dict) -> bool:
return len(example['input_ids']) < max_seq_len
pad_token_id = tokenizer.pad_token_id

prompt_length_filtered_dataset = tokenized_dataset.filter(
filter_long_prompts,
def filter_long_or_empty_examples(example: Dict) -> bool:
less_than_max_seq_len = len(example['input_ids']) < max_seq_len
non_empty_input = len(example['input_ids']) > 0
non_empty_labels = len(example['labels']) > 0
non_padding_response = any(
token_id != pad_token_id for token_id in example['labels'])
return (less_than_max_seq_len and non_empty_input and
non_empty_labels and non_padding_response)

filtered_dataset = tokenized_dataset.filter(
filter_long_or_empty_examples,
num_proc=num_cpus_to_use,
desc='Filtering out long prompts',
)

examples_removed = len(tokenized_dataset) - len(
prompt_length_filtered_dataset)
examples_removed = len(tokenized_dataset) - len(filtered_dataset)
if examples_removed > 0:
warnings.warn(
f'Dropped {examples_removed} examples where the prompt was longer than {max_seq_len}.'
f'Dropped {examples_removed} examples where the prompt was longer than {max_seq_len}, '
+
'the prompt or response was empty, or the response was all padding tokens.'
)

pad_token_id = tokenizer.pad_token_id

def filter_empty_examples(example: Dict) -> bool:
return len(example['input_ids']) > 0 and len(
example['labels']) > 0 and any(
token_id != pad_token_id for token_id in example['labels'])

empty_examples_dropped_dataset = prompt_length_filtered_dataset.filter(
filter_empty_examples,
num_proc=num_cpus_to_use,
desc='Filtering out empty examples')

log.debug('Done tokenizing and filtering examples.')

empty_examples_removed = len(prompt_length_filtered_dataset) - len(
empty_examples_dropped_dataset)
if empty_examples_removed > 0:
warnings.warn(
f'Dropped {empty_examples_removed} examples where the prompt or response was empty, '
+ 'or the response was only padding tokens.')

# Now local rank 0 indicates to the other ranks that it is done
if dist.get_local_rank() == 0:
log.debug('Local rank 0 finished data prep')
@@ -414,7 +402,7 @@ def filter_empty_examples(example: Dict) -> bool:
os.remove(signal_file_path)

log.debug('All ranks finished data prep')
return empty_examples_dropped_dataset
return filtered_dataset

def build_from_streaming(self, *args: Any,
**kwargs: Any) -> StreamingFinetuningDataset:
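As a standalone illustration (not part of this commit), the same keep/drop rules used by `filter_long_or_empty_examples` can be exercised on a toy Hugging Face dataset; `PAD_ID`, `MAX_SEQ_LEN`, and the rows below are invented for the example:

```python
# Illustrative sketch only: applies the combined length/emptiness/padding checks
# to a toy pre-tokenized dataset. Assumes the `datasets` package is installed.
from datasets import Dataset

PAD_ID = 0
MAX_SEQ_LEN = 8

ds = Dataset.from_dict({
    'input_ids': [[5, 6, 7], [1] * 10, [], [2, 3]],
    'labels':    [[8, 9],    [4, 5],   [1], [PAD_ID, PAD_ID]],
})

def keep(example):
    return (len(example['input_ids']) < MAX_SEQ_LEN and   # not too long
            len(example['input_ids']) > 0 and             # non-empty prompt
            len(example['labels']) > 0 and                 # non-empty response
            any(t != PAD_ID for t in example['labels']))   # response not all padding

filtered = ds.filter(keep)
print(f'Dropped {len(ds) - len(filtered)} of {len(ds)} examples')  # expect 3 of 4
```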
22 changes: 13 additions & 9 deletions llmfoundry/utils/builders.py
@@ -190,9 +190,11 @@ def build_tokenizer(

signal_file_path = f'.node_{dist.get_node_rank()}_local_rank0_completed_tokenizer_setup'

# Make sure the tokenizer files are downloaded and cached first by local rank 0
with dist.local_rank_zero_download_and_wait(signal_file_path):
pass
if dist.is_available() and dist.is_initialized(
) and dist.get_world_size() > 1:
# Make sure the tokenizer files are downloaded and cached first by local rank 0
with dist.local_rank_zero_download_and_wait(signal_file_path):
pass

if tokenizer_name.startswith('tiktoken'):
tokenizer = TiktokenTokenizerWrapper(**tokenizer_kwargs)
@@ -208,14 +210,16 @@
int(1e30),
)

if dist.get_local_rank() == 0:
with open(signal_file_path, 'wb') as f:
f.write(b'local_rank0_completed_tokenizer_setup')
if dist.is_available() and dist.is_initialized(
) and dist.get_world_size() > 1:
if dist.get_local_rank() == 0:
with open(signal_file_path, 'wb') as f:
f.write(b'local_rank0_completed_tokenizer_setup')

dist.barrier()
dist.barrier()

if dist.get_local_rank() == 0:
os.remove(signal_file_path)
if dist.get_local_rank() == 0:
os.remove(signal_file_path)

return tokenizer
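
The guard added here follows a common pattern: skip distributed coordination when no process group is running. A generic sketch with `torch.distributed` (the code above uses Composer's `dist` helpers, which wrap it) might look like:

```python
# Generic sketch (not Composer's API): only coordinate across ranks when a
# process group is actually up, so single-process runs skip the barrier entirely.
import torch.distributed as torch_dist

def running_multi_process() -> bool:
    return (torch_dist.is_available() and torch_dist.is_initialized()
            and torch_dist.get_world_size() > 1)

def setup_shared_files():
    # ... every rank prepares/downloads whatever it needs here ...
    if running_multi_process():
        torch_dist.barrier()  # wait for all ranks before continuing
```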

27 changes: 17 additions & 10 deletions llmfoundry/utils/model_download_utils.py
Expand Up @@ -6,6 +6,7 @@
import logging
import os
import time
import warnings
from http import HTTPStatus
from typing import Optional
from urllib.parse import urljoin
@@ -14,6 +15,7 @@
import requests
import tenacity
from bs4 import BeautifulSoup
from requests.packages.urllib3.exceptions import InsecureRequestWarning
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME
from transformers.utils import WEIGHTS_INDEX_NAME as PYTORCH_WEIGHTS_INDEX_NAME
from transformers.utils import WEIGHTS_NAME as PYTORCH_WEIGHTS_NAME
@@ -212,16 +214,21 @@ def download_from_cache_server(

download_start = time.time()

# Only downloads the blobs in order to avoid downloading model files twice due to the
# symlinks in the Hugging Face cache structure:
_recursive_download(
session,
cache_base_url,
# Trailing slash to indicate directory
f'{formatted_model_name}/blobs/',
save_dir,
ignore_cert=ignore_cert,
)
# Temporarily suppress noisy SSL certificate verification warnings if ignore_cert is set to True
with warnings.catch_warnings():
if ignore_cert:
warnings.simplefilter('ignore', category=InsecureRequestWarning)

# Only downloads the blobs in order to avoid downloading model files twice due to the
# symlinks in the Hugging Face cache structure:
_recursive_download(
session,
cache_base_url,
# Trailing slash to indicate directory
f'{formatted_model_name}/blobs/',
save_dir,
ignore_cert=ignore_cert,
)
download_duration = time.time() - download_start
log.info(
f'Downloaded model {model_name} from cache server in {download_duration} seconds'
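For reference, the suppression pattern used above in isolation (a sketch; `fetch_insecure` is an invented helper, and the import is the plain `urllib3` equivalent of the one added in this diff):

```python
# Sketch: silence urllib3's InsecureRequestWarning only while certificate
# verification is intentionally disabled, and only inside this block.
import warnings

import requests
from urllib3.exceptions import InsecureRequestWarning


def fetch_insecure(url: str, ignore_cert: bool = False) -> bytes:
    with warnings.catch_warnings():
        if ignore_cert:
            warnings.simplefilter('ignore', category=InsecureRequestWarning)
        resp = requests.get(url, verify=not ignore_cert)
        resp.raise_for_status()
        return resp.content
```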
20 changes: 18 additions & 2 deletions scripts/misc/download_hf_model.py
@@ -14,6 +14,8 @@

HF_TOKEN_ENV_VAR = 'HUGGING_FACE_HUB_TOKEN'

logging.basicConfig(format=f'%(asctime)s: %(levelname)s: %(name)s: %(message)s',
level=logging.INFO)
log = logging.getLogger(__name__)

if __name__ == '__main__':
@@ -34,7 +36,7 @@
argparser.add_argument(
'--fallback',
action='store_true',
default=False,
default=True,
help=
'Whether to fallback to downloading from Hugging Face if download from cache fails',
)
@@ -53,11 +55,25 @@
token=args.token,
ignore_cert=args.ignore_cert,
)

# A little hacky: run the Hugging Face download just to repair the symlinks in the HF cache file structure.
# This shouldn't actually download any files if the cache server download was successful, but should address
# a non-deterministic bug where the symlinks aren't repaired properly by the time the model is initialized.
log.info('Repairing Hugging Face cache symlinks')

# Hide some noisy logs that aren't important for just the symlink repair.
old_level = logging.getLogger().level
logging.getLogger().setLevel(logging.ERROR)
download_from_hf_hub(args.model,
save_dir=args.save_dir,
token=args.token)
logging.getLogger().setLevel(old_level)

except PermissionError:
log.error(f'Not authorized to download {args.model}.')
except Exception as e:
if args.fallback:
log.warn(
log.warning(
f'Failed to download {args.model} from cache server. Falling back to Hugging Face Hub. Error: {e}'
)
download_from_hf_hub(args.model,
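The temporary log-level change above can also be wrapped in a small context manager so the old level is restored even on error; a hypothetical sketch (this helper does not exist in the repo):

```python
# Hypothetical helper: temporarily raise the root logger's level to hide noisy
# third-party INFO logs, restoring the old level even if the body raises.
import logging
from contextlib import contextmanager

@contextmanager
def quiet_root_logger(level: int = logging.ERROR):
    root = logging.getLogger()
    old_level = root.level
    root.setLevel(level)
    try:
        yield
    finally:
        root.setLevel(old_level)

# Usage, mirroring the symlink-repair call above:
#   with quiet_root_logger():
#       download_from_hf_hub(model, save_dir=save_dir, token=token)
```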

