diff --git a/.github/workflows/docker.yaml b/.github/workflows/docker.yaml
index 83c9a63884..13a835356c 100644
--- a/.github/workflows/docker.yaml
+++ b/.github/workflows/docker.yaml
@@ -29,7 +29,12 @@ jobs:
           - name: '2.1.0_cu121_flash2'
             base_image: mosaicml/pytorch:2.1.0_cu121-python3.10-ubuntu20.04
             dep_groups: '[gpu-flash2]'
-
+          - name: '2.1.0_cu121_aws'
+            base_image: mosaicml/pytorch:2.1.0_cu121-python3.10-ubuntu20.04-aws
+            dep_groups: '[gpu]'
+          - name: '2.1.0_cu121_flash2_aws'
+            base_image: mosaicml/pytorch:2.1.0_cu121-python3.10-ubuntu20.04-aws
+            dep_groups: '[gpu-flash2]'
     steps:
       - name: Maximize Build Space on Worker
         uses: easimon/maximize-build-space@v4
diff --git a/README.md b/README.md
index 46074613e1..4a4e60e844 100644
--- a/README.md
+++ b/README.md
@@ -45,15 +45,15 @@ You'll find in this repo:
 Mosaic Pretrained Transformers (MPT) are GPT-style models with some special features -- Flash Attention for efficiency, ALiBi for context length extrapolation, and stability improvements to mitigate loss spikes. As part of MosaicML's Foundation series, we have open-sourced several MPT models:
 
-| Model              | Context Length | Download                                           | Demo                                                        | Commercial use? |
-|--------------------|----------------|----------------------------------------------------|------------------------------------------------------------------|-----------------|
-| MPT-30B            | 8192           | https://huggingface.co/mosaicml/mpt-30b            |                                                             | Yes             |
-| MPT-30B-Instruct   | 8192           | https://huggingface.co/mosaicml/mpt-30b-instruct   |                                                             | Yes             |
-| MPT-30B-Chat       | 8192           | https://huggingface.co/mosaicml/mpt-30b-chat       | [Demo](https://huggingface.co/spaces/mosaicml/mpt-30b-chat) | No              |
-| MPT-7B             | 2048           | https://huggingface.co/mosaicml/mpt-7b             |                                                             | Yes             |
-| MPT-7B-Instruct    | 2048           | https://huggingface.co/mosaicml/mpt-7b-instruct    |                                                             | Yes             |
-| MPT-7B-Chat        | 2048           | https://huggingface.co/mosaicml/mpt-7b-chat        | [Demo](https://huggingface.co/spaces/mosaicml/mpt-7b-chat)  | No              |
-| MPT-7B-StoryWriter | 65536          | https://huggingface.co/mosaicml/mpt-7b-storywriter |                                                             | Yes             |
+| Model              | Context Length | Download                                           | Demo                                                        | Commercial use? |
+| ------------------ | -------------- | -------------------------------------------------- | ----------------------------------------------------------- | --------------- |
+| MPT-30B            | 8192           | https://huggingface.co/mosaicml/mpt-30b            |                                                             | Yes             |
+| MPT-30B-Instruct   | 8192           | https://huggingface.co/mosaicml/mpt-30b-instruct   |                                                             | Yes             |
+| MPT-30B-Chat       | 8192           | https://huggingface.co/mosaicml/mpt-30b-chat       | [Demo](https://huggingface.co/spaces/mosaicml/mpt-30b-chat) | No              |
+| MPT-7B             | 2048           | https://huggingface.co/mosaicml/mpt-7b             |                                                             | Yes             |
+| MPT-7B-Instruct    | 2048           | https://huggingface.co/mosaicml/mpt-7b-instruct    |                                                             | Yes             |
+| MPT-7B-Chat        | 2048           | https://huggingface.co/mosaicml/mpt-7b-chat        | [Demo](https://huggingface.co/spaces/mosaicml/mpt-7b-chat)  | No              |
+| MPT-7B-StoryWriter | 65536          | https://huggingface.co/mosaicml/mpt-7b-storywriter |                                                             | Yes             |
 
 To try out these models locally, [follow the instructions](https://github.com/mosaicml/llm-foundry/tree/main/scripts/inference#interactive-generation-with-modelgenerate) in `scripts/inference/README.md` to prompt HF models using our [hf_generate.py](https://github.com/mosaicml/llm-foundry/blob/main/scripts/inference/hf_generate.py) or [hf_chat.py](https://github.com/mosaicml/llm-foundry/blob/main/scripts/inference/hf_chat.py) scripts.
@@ -89,17 +89,17 @@ This codebase has been tested with PyTorch 1.13.1 and PyTorch 2.0.1 on systems w
 This codebase may also work on systems with other devices, such as consumer NVIDIA cards and AMD cards, but we are not actively testing these systems. If you have success/failure using LLM Foundry on other systems, please let us know in a Github issue and we will update the support matrix!
 
-| Device         | Torch Version | Cuda Version | Status                       |
-|---------------------------|------------------|--------------|-------------------------------|
-| A100-40GB/80GB | 1.13.1        | 11.7         | :white_check_mark: Supported |
-| A100-40GB/80GB | 2.0.1         | 11.7, 11.8   | :white_check_mark: Supported |
-| A100-40GB/80GB | 2.1.0         | 11.8, 12.1   | :white_check_mark: Supported |
-| H100-80GB      | 1.13.1        | 11.7         | :x: Not Supported            |
-| H100-80GB      | 2.0.1         | 11.8         | :white_check_mark: Supported |
-| H100-80GB      | 2.1.0         | 12.1         | :white_check_mark: Supported |
-| A10-24GB       | 1.13.1        | 11.7         | :construction: In Progress   |
-| A10-24GB       | 2.0.1         | 11.7, 11.8   | :construction: In Progress   |
-| MI250          | 2.0.1         | ROCm 5.4     | :construction: In Progress   |
+| Device         | Torch Version | Cuda Version | Status                       |
+| -------------- | ------------- | ------------ | ---------------------------- |
+| A100-40GB/80GB | 1.13.1        | 11.7         | :white_check_mark: Supported |
+| A100-40GB/80GB | 2.0.1         | 11.7, 11.8   | :white_check_mark: Supported |
+| A100-40GB/80GB | 2.1.0         | 11.8, 12.1   | :white_check_mark: Supported |
+| H100-80GB      | 1.13.1        | 11.7         | :x: Not Supported            |
+| H100-80GB      | 2.0.1         | 11.8         | :white_check_mark: Supported |
+| H100-80GB      | 2.1.0         | 12.1         | :white_check_mark: Supported |
+| A10-24GB       | 1.13.1        | 11.7         | :construction: In Progress   |
+| A10-24GB       | 2.0.1         | 11.7, 11.8   | :construction: In Progress   |
+| MI250          | 2.0.1         | ROCm 5.4     | :construction: In Progress   |
 
 ## MosaicML Docker Images
 We highly recommend using our prebuilt Docker images. You can find them here: https://hub.docker.com/orgs/mosaicml/repositories.
@@ -111,15 +111,17 @@ You can select a specific commit hash such as `mosaicml/llm-foundry:1.13.1_cu117
 **Please Note:** The `mosaicml/llm-foundry` images do not come with the `llm-foundry` package preinstalled, just the dependencies. You will still need to `pip install llm-foundry` either from PyPi or from source.
 
-| Docker Image                                            | Torch Version | Cuda Version | LLM Foundry dependencies installed? |
-|-------------------------------------------------------------|----------------|--------------|-------------------------------------|
-| `mosaicml/pytorch:1.13.1_cu117-python3.10-ubuntu20.04`  | 1.13.1        | 11.7         | No                                  |
-| `mosaicml/pytorch:2.0.1_cu118-python3.10-ubuntu20.04`   | 2.0.1         | 11.8         | No                                  |
-| `mosaicml/pytorch:2.1.0_cu121-python3.10-ubuntu20.04`   | 2.1.0         | 12.1         | No                                  |
-| `mosaicml/llm-foundry:1.13.1_cu117-latest`              | 1.13.1        | 11.7         | Yes                                 |
-| `mosaicml/llm-foundry:2.0.1_cu118-latest`               | 2.0.1         | 11.8         | Yes                                 |
-| `mosaicml/llm-foundry:2.1.0_cu121-latest`               | 2.1.0         | 12.1         | Yes (flash attention v1)            |
-| `mosaicml/llm-foundry:2.1.0_cu121_flash2-latest`        | 2.1.0         | 12.1         | Yes (flash attention v2)            |
+| Docker Image                                            | Torch Version | Cuda Version      | LLM Foundry dependencies installed? |
| +| ------------------------------------------------------ | ------------- | ----------------- | ----------------------------------- | +| `mosaicml/pytorch:1.13.1_cu117-python3.10-ubuntu20.04` | 1.13.1 | 11.7 (Infiniband) | No | +| `mosaicml/pytorch:2.0.1_cu118-python3.10-ubuntu20.04` | 2.0.1 | 11.8 (Infiniband) | No | +| `mosaicml/pytorch:2.1.0_cu121-python3.10-ubuntu20.04` | 2.1.0 | 12.1 (Infiniband) | No | +| `mosaicml/llm-foundry:1.13.1_cu117-latest` | 1.13.1 | 11.7 (Infiniband) | Yes | +| `mosaicml/llm-foundry:2.0.1_cu118-latest` | 2.0.1 | 11.8 (Infiniband) | Yes | +| `mosaicml/llm-foundry:2.1.0_cu121-latest` | 2.1.0 | 12.1 (Infiniband) | Yes (flash attention v1) | +| `mosaicml/llm-foundry:2.1.0_cu121_flash2-latest` | 2.1.0 | 12.1 (Infiniband) | Yes (flash attention v2) | +| `mosaicml/llm-foundry:2.1.0_cu121_aws-latest` | 2.1.0 | 12.1 (EFA) | Yes (flash attention v1) | +| `mosaicml/llm-foundry:2.1.0_cu121_flash2_aws-latest` | 2.1.0 | 12.1 (EFA) | Yes (flash attention v2) | # Installation diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py index 4f400738e4..e02bf03693 100644 --- a/llmfoundry/callbacks/hf_checkpointer.py +++ b/llmfoundry/callbacks/hf_checkpointer.py @@ -14,9 +14,10 @@ from composer.core import Callback, Event, State, Time, TimeUnit from composer.core.state import fsdp_state_dict_type_context from composer.loggers import Logger, MLFlowLogger -from composer.loggers.remote_uploader_downloader import RemoteUploaderDownloader from composer.models import HuggingFaceModel -from composer.utils import dist, format_name_with_dist_and_time, parse_uri +from composer.utils import (dist, format_name_with_dist_and_time, + maybe_create_remote_uploader_downloader_from_uri, + parse_uri) from composer.utils.misc import create_interval_scheduler from transformers import PreTrainedModel, PreTrainedTokenizerBase @@ -57,8 +58,7 @@ def __init__( mlflow_registered_model_name: Optional[str] = None, mlflow_logging_config: Optional[dict] = None, ): - self.backend, self.bucket_name, self.save_dir_format_str = parse_uri( - save_folder) + _, _, self.save_dir_format_str = parse_uri(save_folder) self.overwrite = overwrite self.precision = precision self.dtype = { @@ -93,13 +93,11 @@ def __init__( self.save_interval = save_interval self.check_interval = create_interval_scheduler( save_interval, include_end_of_training=True) - self.upload_to_object_store = (self.backend != '') - if self.upload_to_object_store: - self.remote_ud = RemoteUploaderDownloader( - bucket_uri=f'{self.backend}://{self.bucket_name}', - num_concurrent_uploads=4) - else: - self.remote_ud = None + + self.remote_ud = maybe_create_remote_uploader_downloader_from_uri( + save_folder, loggers=[]) + if self.remote_ud is not None: + self.remote_ud._num_concurrent_uploads = 4 self.last_checkpoint_batch: Optional[Time] = None self.mlflow_loggers = [] @@ -115,7 +113,7 @@ def run_event(self, event: Event, state: State, logger: Logger) -> None: raise ValueError( f'`HuggingFaceCheckpointer` is only compatible with `HuggingFaceModel`s. 
-            if self.upload_to_object_store and self.remote_ud is not None:
+            if self.remote_ud is not None:
                 self.remote_ud.init(state, logger)
                 state.callbacks.append(self.remote_ud)
@@ -169,7 +167,7 @@ def _save_checkpoint(self, state: State, logger: Logger):
                 self.huggingface_folder_name_fstr), state.run_name,
             state.timestamp)
         dir_context_mgr = tempfile.TemporaryDirectory(
-        ) if self.upload_to_object_store else contextlib.nullcontext(
+        ) if self.remote_ud is not None else contextlib.nullcontext(
             enter_result=save_dir)
 
         with dir_context_mgr as temp_save_dir:
@@ -233,11 +231,8 @@ def _save_checkpoint(self, state: State, logger: Logger):
                 log.debug('Editing MPT files for HuggingFace compatibility')
                 edit_files_for_hf_compatibility(temp_save_dir)
 
-            if self.upload_to_object_store:
-                assert self.remote_ud is not None
-                log.info(
-                    f'Uploading HuggingFace formatted checkpoint to {self.backend}://{self.bucket_name}/{save_dir}'
-                )
+            if self.remote_ud is not None:
+                log.info(f'Uploading HuggingFace formatted checkpoint')
                 for filename in os.listdir(temp_save_dir):
                     self.remote_ud.upload_file(
                         state=state,
diff --git a/llmfoundry/data/finetuning/tasks.py b/llmfoundry/data/finetuning/tasks.py
index 3673a48217..6ba6ad96c8 100644
--- a/llmfoundry/data/finetuning/tasks.py
+++ b/llmfoundry/data/finetuning/tasks.py
@@ -362,36 +362,32 @@ def dataset_mapper(example: Dict):
             num_proc=num_cpus_to_use,
             desc='Tokenizing dataset',
         )
-        prompt_length_filtered_dataset = tokenized_dataset.filter(
-            lambda example: len(example['input_ids']) < max_seq_len,
+
+        pad_token_id = tokenizer.pad_token_id
+
+        def filter_long_or_empty_examples(example: Dict) -> bool:
+            less_than_max_seq_len = len(example['input_ids']) < max_seq_len
+            non_empty_input = len(example['input_ids']) > 0
+            non_empty_labels = len(example['labels']) > 0
+            non_padding_response = any(
+                token_id != pad_token_id for token_id in example['labels'])
+            return (less_than_max_seq_len and non_empty_input and
+                    non_empty_labels and non_padding_response)
+
+        filtered_dataset = tokenized_dataset.filter(
+            filter_long_or_empty_examples,
             num_proc=num_cpus_to_use,
             desc='Filtering out long prompts',
         )
-        examples_removed = len(tokenized_dataset) - len(
-            prompt_length_filtered_dataset)
+        examples_removed = len(tokenized_dataset) - len(filtered_dataset)
         if examples_removed > 0:
             warnings.warn(
-                f'Dropped {examples_removed} examples where the prompt was longer than {max_seq_len}.'
+                f'Dropped {examples_removed} examples where the prompt was longer than {max_seq_len}, '
+                +
+                'the prompt or response was empty, or the response was all padding tokens.'
             )
 
-        pad_token_id = tokenizer.pad_token_id
-        empty_examples_dropped_dataset = prompt_length_filtered_dataset.filter(
-            lambda example: len(example['input_ids']) > 0 and len(example[
-                'labels']) > 0 and any(token_id != pad_token_id
-                                       for token_id in example['labels']),
-            num_proc=num_cpus_to_use,
-            desc='Filtering out empty examples')
-
-        log.debug('Done tokenizing and filtering examples.')
-
-        empty_examples_removed = len(prompt_length_filtered_dataset) - len(
-            empty_examples_dropped_dataset)
-        if empty_examples_removed > 0:
-            warnings.warn(
-                f'Dropped {empty_examples_removed} examples where the prompt or response was empty, '
-                + 'or the response was only padding tokens.')
-
         # Now local rank 0 indicates to the other ranks that it is done
         if dist.get_local_rank() == 0:
             log.debug('Local rank 0 finished data prep')
@@ -406,7 +402,7 @@ def dataset_mapper(example: Dict):
             os.remove(signal_file_path)
 
         log.debug('All ranks finished data prep')
-        return empty_examples_dropped_dataset
+        return filtered_dataset
 
     def build_from_streaming(self, *args: Any,
                              **kwargs: Any) -> StreamingFinetuningDataset:
diff --git a/llmfoundry/utils/builders.py b/llmfoundry/utils/builders.py
index f027afb0ce..142e714b55 100644
--- a/llmfoundry/utils/builders.py
+++ b/llmfoundry/utils/builders.py
@@ -188,6 +188,14 @@ def build_tokenizer(
     os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = '1'
     os.environ['TOKENIZERS_PARALLELISM'] = 'false'
 
+    signal_file_path = f'.node_{dist.get_node_rank()}_local_rank0_completed_tokenizer_setup'
+
+    if dist.is_available() and dist.is_initialized(
+    ) and dist.get_world_size() > 1:
+        # Make sure the tokenizer files are downloaded and cached first by local rank 0
+        with dist.local_rank_zero_download_and_wait(signal_file_path):
+            pass
+
     if tokenizer_name.startswith('tiktoken'):
         tokenizer = TiktokenTokenizerWrapper(**tokenizer_kwargs)
     else:
@@ -202,6 +210,17 @@ def build_tokenizer(
             int(1e30),
         )
 
+    if dist.is_available() and dist.is_initialized(
+    ) and dist.get_world_size() > 1:
+        if dist.get_local_rank() == 0:
+            with open(signal_file_path, 'wb') as f:
+                f.write(b'local_rank0_completed_tokenizer_setup')
+
+        dist.barrier()
+
+        if dist.get_local_rank() == 0:
+            os.remove(signal_file_path)
+
     return tokenizer
diff --git a/llmfoundry/utils/model_download_utils.py b/llmfoundry/utils/model_download_utils.py
index d268cb78b7..2104455e0f 100644
--- a/llmfoundry/utils/model_download_utils.py
+++ b/llmfoundry/utils/model_download_utils.py
@@ -6,6 +6,7 @@
 import logging
 import os
 import time
+import warnings
 from http import HTTPStatus
 from typing import Optional
 from urllib.parse import urljoin
@@ -14,6 +15,7 @@
 import requests
 import tenacity
 from bs4 import BeautifulSoup
+from requests.packages.urllib3.exceptions import InsecureRequestWarning
 from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME
 from transformers.utils import WEIGHTS_INDEX_NAME as PYTORCH_WEIGHTS_INDEX_NAME
 from transformers.utils import WEIGHTS_NAME as PYTORCH_WEIGHTS_NAME
@@ -212,16 +214,21 @@ def download_from_cache_server(
 
     download_start = time.time()
 
-    # Only downloads the blobs in order to avoid downloading model files twice due to the
-    # symlnks in the Hugging Face cache structure:
-    _recursive_download(
-        session,
-        cache_base_url,
-        # Trailing slash to indicate directory
-        f'{formatted_model_name}/blobs/',
-        save_dir,
-        ignore_cert=ignore_cert,
-    )
+    # Temporarily suppress noisy SSL certificate verification warnings if ignore_cert is set to True
+    with warnings.catch_warnings():
+        if ignore_cert:
+            warnings.simplefilter('ignore', category=InsecureRequestWarning)
+
+        # Only downloads the blobs in order to avoid downloading model files twice due to the
+        # symlinks in the Hugging Face cache structure:
+        _recursive_download(
+            session,
+            cache_base_url,
+            # Trailing slash to indicate directory
+            f'{formatted_model_name}/blobs/',
+            save_dir,
+            ignore_cert=ignore_cert,
+        )
 
     download_duration = time.time() - download_start
     log.info(
         f'Downloaded model {model_name} from cache server in {download_duration} seconds'
diff --git a/scripts/misc/download_hf_model.py b/scripts/misc/download_hf_model.py
index 6465a552c2..58c3445e7d 100644
--- a/scripts/misc/download_hf_model.py
+++ b/scripts/misc/download_hf_model.py
@@ -14,6 +14,8 @@
 
 HF_TOKEN_ENV_VAR = 'HUGGING_FACE_HUB_TOKEN'
 
+logging.basicConfig(format=f'%(asctime)s: %(levelname)s: %(name)s: %(message)s',
+                    level=logging.INFO)
 log = logging.getLogger(__name__)
 
 if __name__ == '__main__':
@@ -34,7 +36,7 @@
     argparser.add_argument(
         '--fallback',
         action='store_true',
-        default=False,
+        default=True,
         help=
         'Whether to fallback to downloading from Hugging Face if download from cache fails',
     )
@@ -53,11 +55,25 @@
                 token=args.token,
                 ignore_cert=args.ignore_cert,
             )
+
+            # A little hacky: run the Hugging Face download just to repair the symlinks in the HF cache file structure.
+            # This shouldn't actually download any files if the cache server download was successful, but should address
+            # a non-deterministic bug where the symlinks aren't repaired properly by the time the model is initialized.
+            log.info('Repairing Hugging Face cache symlinks')
+
+            # Hide some noisy logs that aren't important for just the symlink repair.
+            old_level = logging.getLogger().level
+            logging.getLogger().setLevel(logging.ERROR)
+            download_from_hf_hub(args.model,
+                                 save_dir=args.save_dir,
+                                 token=args.token)
+            logging.getLogger().setLevel(old_level)
+
         except PermissionError:
             log.error(f'Not authorized to download {args.model}.')
         except Exception as e:
             if args.fallback:
-                log.warn(
+                log.warning(
                     f'Failed to download {args.model} from cache server. Falling back to Hugging Face Hub. Error: {e}'
                 )
                 download_from_hf_hub(args.model,