From f0fd749ca0aa9eb1510e27805458891482eabac6 Mon Sep 17 00:00:00 2001 From: Irene Dea Date: Thu, 4 Jan 2024 18:36:54 -0800 Subject: [PATCH 01/31] Add safe_load option to restrict HF dataset downloads to allowed file types (#798) --- .gitignore | 1 + llmfoundry/data/finetuning/dataloader.py | 163 +++++++++++------------ llmfoundry/data/finetuning/tasks.py | 84 +++++++++--- tests/data/test_dataloader.py | 59 +++++++- tests/data_utils.py | 8 +- 5 files changed, 203 insertions(+), 112 deletions(-) diff --git a/.gitignore b/.gitignore index 989fb3af0c..d041a25c22 100644 --- a/.gitignore +++ b/.gitignore @@ -155,3 +155,4 @@ notebooks/ **/*.pt **/mlruns/* **/tokenizer-save-dir-*/** +**/.downloaded_finetuning/ diff --git a/llmfoundry/data/finetuning/dataloader.py b/llmfoundry/data/finetuning/dataloader.py index 7a29d1dfed..97725ce78c 100644 --- a/llmfoundry/data/finetuning/dataloader.py +++ b/llmfoundry/data/finetuning/dataloader.py @@ -4,7 +4,6 @@ import os from typing import Tuple, Union -import datasets as hf_datasets import torch from composer.core.data_spec import DataSpec from composer.utils import dist, get_file, parse_uri @@ -13,7 +12,9 @@ from transformers import PreTrainedTokenizerBase from llmfoundry.data.finetuning.collator import Seq2SeqFinetuningCollator -from llmfoundry.data.finetuning.tasks import dataset_constructor +from llmfoundry.data.finetuning.tasks import (DOWNLOADED_FT_DATASETS_DIRPATH, + SUPPORTED_EXTENSIONS, + dataset_constructor) from llmfoundry.data.packing import BinPackCollator, auto_packing_ratio from llmfoundry.data.text_data import get_tokens_per_batch_func @@ -122,8 +123,13 @@ def build_finetuning_dataloader(cfg: DictConfig, if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token + collate_fn, dataloader_batch_size = _build_collate_fn( + cfg, tokenizer, device_batch_size) + dataset = None # for pyright + sampler = None if cfg.dataset.get('remote') is not None: + # Build streaming dataloader dataset = dataset_constructor.build_from_streaming( tokenizer=tokenizer, local=cfg.dataset.local, @@ -148,40 +154,45 @@ def build_finetuning_dataloader(cfg: DictConfig, batching_method=cfg.dataset.get('batching_method', 'random'), ) - collate_fn, dataloader_batch_size = _build_collate_fn( - cfg, tokenizer, device_batch_size) - - dl = DataLoader( - dataset, - collate_fn=collate_fn, - batch_size=dataloader_batch_size, - drop_last=cfg.drop_last, - num_workers=cfg.num_workers, - pin_memory=cfg.get('pin_memory', True), - prefetch_factor=cfg.get('prefetch_factor', 2), - persistent_workers=cfg.get('persistent_workers', True), - timeout=cfg.get('timeout', 0), - ) - else: - backend, _, _ = parse_uri(cfg.dataset.hf_name) + # Build HF dataloader + dataset_name_or_path = cfg.dataset.hf_name + split = cfg.dataset.get('split') + + # If dataset is a remote path, download it first. + backend, _, _ = parse_uri(dataset_name_or_path) if backend not in ['', None]: - if cfg.dataset.get('split') is None: + if split is None: raise ValueError( 'When using a HuggingFace dataset from a URL, you must set the ' + \ '`split` key in the dataset config.' ) - dataset = _build_hf_dataset_from_remote(cfg, tokenizer) + # HF datasets does not support a split with dashes, so we replace dashes + # with underscores. + split = split.replace('-', '_') + dataset_name_or_path = _download_remote_hf_dataset( + remote_path=dataset_name_or_path, split=split) + + # Get the preprocessing function. + proto_preprocessing_fn = cfg.dataset.get('preprocessing_fn') + if isinstance(proto_preprocessing_fn, (dict, DictConfig)): + preprocessing_fn = dataset_constructor.get_preprocessing_fn_from_dict( + dict(proto_preprocessing_fn)) else: - dataset = dataset_constructor.build_from_hf( - cfg.dataset, - max_seq_len=cfg.dataset.max_seq_len, - tokenizer=tokenizer, - ) + preprocessing_fn = dataset_constructor.get_preprocessing_fn_from_str( + proto_preprocessing_fn, dataset_name_or_path) - collate_fn, dataloader_batch_size = _build_collate_fn( - cfg, tokenizer, device_batch_size) + # Build dataset from HF. + dataset = dataset_constructor.build_from_hf( + dataset_name=dataset_name_or_path, + split=split, + safe_load=cfg.dataset.get('safe_load', False), + max_seq_len=cfg.dataset.max_seq_len, + preprocessing_fn=preprocessing_fn, + tokenizer=tokenizer, + hf_kwargs=cfg.dataset.get('kwargs', {})) + # Ensure dataset is large enough. if cfg.drop_last: world_size = dist.get_world_size() minimum_dataset_size = world_size * dataloader_batch_size @@ -189,7 +200,7 @@ def build_finetuning_dataloader(cfg: DictConfig, full_dataset_size = len(dataset) if full_dataset_size < minimum_dataset_size: raise ValueError( - f'Your dataset (name={cfg.dataset.hf_name}, split={cfg.dataset.split}) ' + f'Your dataset (name={cfg.dataset.hf_name}, split={split}) ' + f'has {full_dataset_size} samples, but your minimum batch size ' + @@ -199,22 +210,24 @@ def build_finetuning_dataloader(cfg: DictConfig, + f'of samples in your dataset to at least {minimum_dataset_size}.' ) - - assert dataset is not None - dl = DataLoader( - dataset, - collate_fn=collate_fn, - batch_size=dataloader_batch_size, - drop_last=cfg.drop_last, - sampler=dist.get_sampler(dataset, - drop_last=cfg.drop_last, - shuffle=cfg.dataset.shuffle), - num_workers=cfg.num_workers, - pin_memory=cfg.get('pin_memory', True), - prefetch_factor=cfg.get('prefetch_factor', 2), - persistent_workers=cfg.get('persistent_workers', True), - timeout=cfg.get('timeout', 0), - ) + # Initialize sampler. + sampler = dist.get_sampler(dataset, + drop_last=cfg.drop_last, + shuffle=cfg.dataset.shuffle) + + assert dataset is not None # for pyright + dl = DataLoader( + dataset, + collate_fn=collate_fn, + batch_size=dataloader_batch_size, + drop_last=cfg.drop_last, + sampler=sampler, + num_workers=cfg.num_workers, + pin_memory=cfg.get('pin_memory', True), + prefetch_factor=cfg.get('prefetch_factor', 2), + persistent_workers=cfg.get('persistent_workers', True), + timeout=cfg.get('timeout', 0), + ) token_counting_func = get_tokens_per_batch_func() @@ -250,7 +263,7 @@ def _validate_config(dataset_cfg: DictConfig) -> None: ) elif dataset_cfg.get('remote') is not None: # Using the streaming dataset codepath - illegal_keys = ['hf_name', 'hf_kwargs', 'preprocessing_fn'] + illegal_keys = ['hf_name', 'hf_kwargs', 'preprocessing_fn', 'safe_load'] discovered_illegal_keys = [] for key in illegal_keys: if dataset_cfg.get(key) is not None: @@ -275,11 +288,8 @@ def _validate_config(dataset_cfg: DictConfig) -> None: ) -def _build_hf_dataset_from_remote( - cfg: DictConfig, tokenizer: PreTrainedTokenizerBase -) -> Union[hf_datasets.DatasetDict, hf_datasets.Dataset, - hf_datasets.IterableDatasetDict, hf_datasets.IterableDataset]: - """Builds a dataset from a remote object store. +def _download_remote_hf_dataset(remote_path: str, split: str) -> str: + """Downloads a dataset from a remote object store. This function supports 'jsonl', 'csv', and 'parquet' file formats for the dataset. It will attempt to download the dataset, then once it is downloaded, convert it into HuggingFace ``datasets`` format, and then return this @@ -290,38 +300,26 @@ def _build_hf_dataset_from_remote( completed, the function removes the signal file. Args: - cfg (DictConfig): The configuration dictionary containing the necessary parameters to load the dataset. - This includes: - - dataset.hf_name: The path of the HuggingFace dataset to download. - - dataset.split: The dataset split to download (e.g., 'train', 'validation', 'test'). - - dataset.max_seq_len: The maximum sequence length for tokenizing the dataset. - - tokenizer (Tokenizer): The tokenizer to be used to tokenize the dataset. + hf_name (str): The path of the HuggingFace dataset to download. + split (str): The dataset split to download (e.g., 'train', 'validation', 'test'). Returns: - Dataset: A HuggingFace dataset built from the remote file, prepared and tokenized for fine-tuning the model. + A local directory path where the dataset files are stored. Raises: FileNotFoundError: Raised if the dataset file cannot be found with any of the supported extensions. """ - supported_extensions = ['jsonl', 'csv', 'parquet'] - # HF datasets does not support a split with dashes, so we replace dashes - # with underscores in the destination split. - destination_split = cfg.dataset.split.replace('-', '_') finetune_dir = os.path.join( - os.path.dirname( - os.path.dirname(os.path.dirname(os.path.realpath(__file__)))), - 'downloaded_finetuning', - destination_split if destination_split != 'data' else 'data_not', + DOWNLOADED_FT_DATASETS_DIRPATH, + split if split != 'data' else 'data_not', ) os.makedirs(finetune_dir, exist_ok=True) - for extension in supported_extensions: - name = f'{cfg.dataset.hf_name.strip("/")}/{cfg.dataset.split}.{extension}' + for extension in SUPPORTED_EXTENSIONS: + name = f'{remote_path.strip("/")}/{split}{extension}' destination = str( os.path.abspath( - os.path.join( - finetune_dir, 'data', - f'{destination_split}-00000-of-00001.{extension}'))) + os.path.join(finetune_dir, 'data', + f'{split}-00000-of-00001{extension}'))) # Since we don't know exactly what the extension will be, since it is one of a list # use a signal file to wait for instead of the desired file @@ -331,14 +329,14 @@ def _build_hf_dataset_from_remote( try: get_file(path=name, destination=destination, overwrite=True) except FileNotFoundError as e: - if extension == supported_extensions[-1]: + if extension == SUPPORTED_EXTENSIONS[-1]: files_searched = [ - f'{cfg.dataset.hf_name}/{cfg.dataset.split}.{ext}' - for ext in supported_extensions + f'{cfg.dataset.hf_name}/{cfg.dataset.split}{ext}' + for ext in SUPPORTED_EXTENSIONS ] raise FileNotFoundError( f'Could not find a file with any of ' + \ - f'the supported extensions: {supported_extensions}\n' + \ + f'the supported extensions: {SUPPORTED_EXTENSIONS}\n' + \ f'at {files_searched}' ) from e else: @@ -350,25 +348,18 @@ def _build_hf_dataset_from_remote( with open(signal_file_path, 'wb') as f: f.write(b'local_rank0_completed_download') - # Avoid the collective call until the local rank zero has finished trying to download the checkpoint + # Avoid the collective call until the local rank zero has finished trying to download the dataset # so that we don't timeout for large downloads. This syncs all processes on the node with dist.local_rank_zero_download_and_wait(signal_file_path): - # Then, wait to ensure every node has finished downloading the checkpoint + # Then, wait to ensure every node has finished trying to download the dataset dist.barrier() # clean up signal file if dist.get_local_rank() == 0: os.remove(signal_file_path) dist.barrier() - - cfg.dataset.hf_name = finetune_dir - log.info(cfg.dataset) - dataset = dataset_constructor.build_from_hf( - cfg.dataset, - max_seq_len=cfg.dataset.max_seq_len, - tokenizer=tokenizer, - ) - return dataset + break + return finetune_dir def _build_collate_fn( diff --git a/llmfoundry/data/finetuning/tasks.py b/llmfoundry/data/finetuning/tasks.py index 21c3558b2d..e61d138c41 100644 --- a/llmfoundry/data/finetuning/tasks.py +++ b/llmfoundry/data/finetuning/tasks.py @@ -35,11 +35,12 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]: import logging import os import warnings +from pathlib import Path from typing import Any, Callable, Dict, List, Optional, Union import datasets as hf_datasets +import huggingface_hub as hf_hub from composer.utils import dist -from omegaconf import DictConfig from streaming import StreamingDataset from transformers import PreTrainedTokenizerBase @@ -51,6 +52,22 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]: _ALLOWED_RESPONSE_KEYS = {'response', 'completion'} _ALLOWED_PROMPT_KEYS = {'prompt'} +DOWNLOADED_FT_DATASETS_DIRPATH = os.path.abspath( + os.path.join(os.path.realpath(__file__), os.pardir, os.pardir, os.pardir, + '.downloaded_finetuning')) +SUPPORTED_EXTENSIONS = ['.csv', '.jsonl', '.parquet'] + + +def _is_empty_or_nonexistent(dirpath: str) -> bool: + """Check if a directory is empty or non-existent. + + Args: + dirpath (str): Directory path to check. + + Returns + True if directory is empty or non-existent. False otherwise. + """ + return not os.path.isdir(dirpath) or len(os.listdir(dirpath)) == 0 def _tokenize_formatted_example( @@ -241,8 +258,9 @@ def print_registered_tasks(self) -> None: log.info('\n'.join(tasks)) def get_preprocessing_fn_from_dict( - self, mapping: Union[Dict, DictConfig] - ) -> Callable[[Dict[str, Any]], Dict[str, str]]: + self, + mapping: Dict[str, + str]) -> Callable[[Dict[str, Any]], Dict[str, str]]: """Get a preprocessing function from a dictionary. The dictionary maps column names in the dataset to "prompt" and "response". @@ -327,8 +345,10 @@ def get_preprocessing_fn_from_str( return preprocessing_fn def build_from_hf( - self, cfg: DictConfig, max_seq_len: int, - tokenizer: PreTrainedTokenizerBase + self, dataset_name: str, split: Optional[str], safe_load: bool, + max_seq_len: int, preprocessing_fn: Optional[Callable[[dict[str, Any]], + dict[str, str]]], + tokenizer: PreTrainedTokenizerBase, hf_kwargs: Dict[str, Any] ) -> Union[hf_datasets.DatasetDict, hf_datasets.Dataset, hf_datasets.IterableDatasetDict, hf_datasets.IterableDataset]: """Load a HuggingFace Datasets, preprocess, and tokenize. @@ -343,20 +363,6 @@ def build_from_hf( Returns: Dataset: The tokenized dataset. """ - dataset_name = cfg.hf_name - # HF datasets does not support a split with dashes,so we replace split - # dashes with underscore. - split = cfg.split.replace('-', '_') - kwargs = cfg.get('hf_kwargs', {}) - proto_preprocessing_fn = cfg.get('preprocessing_fn') - if isinstance(proto_preprocessing_fn, dict) or isinstance( - proto_preprocessing_fn, DictConfig): - preprocessing_fn = self.get_preprocessing_fn_from_dict( - proto_preprocessing_fn) - else: - preprocessing_fn = self.get_preprocessing_fn_from_str( - proto_preprocessing_fn, dataset_name) - signal_file_path = f'.node_{dist.get_node_rank()}_local_rank0_data_prep_completed' # Non local rank 0 ranks will wait here for local rank 0 to finish the data processing. @@ -379,9 +385,47 @@ def build_from_hf( error: Optional[Exception] = None filtered_dataset = None try: + if safe_load: + if not os.path.isdir(dataset_name): + # dataset_name is not a local dir path, download if needed. + local_dataset_dir = os.path.join( + DOWNLOADED_FT_DATASETS_DIRPATH, dataset_name) + + if _is_empty_or_nonexistent(dirpath=local_dataset_dir): + # Safely load a dataset from HF Hub with restricted file types. + hf_hub.snapshot_download( + dataset_name, + repo_type='dataset', + allow_patterns=[ + '*' + ext for ext in SUPPORTED_EXTENSIONS + ], + token=hf_kwargs.get('token', None), + revision=hf_kwargs.get('revision', None), + local_dir_use_symlinks=False, + local_dir=local_dataset_dir) + if _is_empty_or_nonexistent(dirpath=local_dataset_dir): + raise FileNotFoundError( + f'safe_load is set to True. No data files with safe extensions {SUPPORTED_EXTENSIONS} ' + + f'found for dataset {dataset_name}. ') + # Set dataset_name to the downloaded location. + dataset_name = local_dataset_dir + + # dataset_name is a local dir path. Use the abspath to prevent confusion. + dataset_name = os.path.abspath(dataset_name) + + # Ensure that the local dir contains only allowed file types. + dataset_files = [ + f for _, _, files in os.walk(dataset_name) for f in files + ] + if not all( + Path(f).suffix in SUPPORTED_EXTENSIONS + for f in dataset_files): + raise ValueError( + f'Dataset at local path {dataset_name} contains invalid file types. ' + + f'Allowed file types are: {SUPPORTED_EXTENSIONS}') dataset = hf_datasets.load_dataset(dataset_name, split=split, - **kwargs) + **hf_kwargs) def dataset_mapper(example: Dict): if preprocessing_fn is not None: diff --git a/tests/data/test_dataloader.py b/tests/data/test_dataloader.py index bf818347a0..7f99eeda25 100644 --- a/tests/data/test_dataloader.py +++ b/tests/data/test_dataloader.py @@ -7,7 +7,9 @@ import shutil import tempfile from argparse import Namespace -from typing import Literal, Optional, Union +from contextlib import nullcontext as does_not_raise +from pathlib import Path +from typing import ContextManager, Literal, Optional, Union from unittest.mock import MagicMock import pytest @@ -23,6 +25,8 @@ from llmfoundry.data import build_dataloader from llmfoundry.data.finetuning.tasks import (_ALLOWED_PROMPT_KEYS, _ALLOWED_RESPONSE_KEYS, + DOWNLOADED_FT_DATASETS_DIRPATH, + SUPPORTED_EXTENSIONS, _tokenize_formatted_example) from llmfoundry.data.text_data import (ConcatenatedSequenceCollatorWrapper, build_text_dataloader, @@ -306,6 +310,51 @@ def test_finetuning_dataloader(decoder_only_format: bool, break +@pytest.mark.parametrize( + 'hf_name, hf_revision, expectation', + [('HuggingFaceH4/databricks_dolly_15k', None, does_not_raise()), + ('squad', '5fe18c', pytest.raises(FileNotFoundError))]) +def test_finetuning_dataloader_safe_load(hf_name: str, + hf_revision: Optional[str], + expectation: ContextManager): + cfg = DictConfig({ + 'name': 'finetuning', + 'dataset': { + 'hf_name': hf_name, + 'split': 'train', + 'max_seq_len': 8, + 'decoder_only_format': True, + 'shuffle': True, + 'safe_load': True, + 'kwargs': { + 'revision': hf_revision + } + }, + 'drop_last': False, + 'num_workers': 0, + 'pin_memory': False, + 'prefetch_factor': None, + 'persistent_workers': False, + 'timeout': 0 + }) + + tokenizer = build_tokenizer('gpt2', {}) + + with expectation: + _ = build_finetuning_dataloader(cfg, tokenizer, 1) + + # If no raised errors, we should expect downloaded files with only safe file types. + if expectation == does_not_raise(): + download_dir = os.path.join(DOWNLOADED_FT_DATASETS_DIRPATH, hf_name) + downloaded_files = [ + file for _, _, files in os.walk(download_dir) for file in files + ] + assert len(downloaded_files) > 0 + assert all( + Path(file).suffix in SUPPORTED_EXTENSIONS + for file in downloaded_files) + + @pytest.mark.world_size(2) @pytest.mark.gpu @pytest.mark.parametrize('dataset_size', [4, 8]) @@ -441,12 +490,16 @@ def test_finetuning_dataloader_custom_split(tmp_path: pathlib.Path, split: str): def mock_get_file(path: str, destination: str, overwrite: bool = False): - make_tiny_ft_dataset(path=destination, size=16) + if Path(destination).suffix == '.jsonl': + make_tiny_ft_dataset(path=destination, size=16) + else: + raise FileNotFoundError( + f'Test error in mock_get_file. {path} does not exist.') @pytest.mark.parametrize('split', ['train', 'custom', 'custom-dash', 'data']) def test_finetuning_dataloader_custom_split_remote( - tmp_path: pathlib.Path, split: str, monkeypatch: pytest.MonkeyPatch): + split: str, monkeypatch: pytest.MonkeyPatch): tokenizer_name = 'gpt2' max_seq_len = 2048 diff --git a/tests/data_utils.py b/tests/data_utils.py index a0ad6bcd13..0afe3fa1a1 100644 --- a/tests/data_utils.py +++ b/tests/data_utils.py @@ -3,9 +3,9 @@ import json import os -import pathlib import shutil from argparse import Namespace +from pathlib import Path from typing import Optional from omegaconf import DictConfig @@ -26,6 +26,8 @@ def make_tiny_ft_dataset( start_token: Optional[str] = None, end_token: Optional[str] = None, ): + if Path(path).suffix != '.jsonl': + raise ValueError(f'Path {path} must be a jsonl file.') good_sample = {'prompt': 'hello', 'response': 'goodbye'} samples = [good_sample] * size if add_bad_data_dropped: @@ -77,7 +79,7 @@ def make_tiny_ft_dataset( _f.write('\n') -def create_c4_dataset_xxsmall(path: pathlib.Path) -> str: +def create_c4_dataset_xxsmall(path: Path) -> str: """Creates a small mocked version of the C4 dataset.""" c4_dir = os.path.join(path, f'my-copy-c4') downloaded_split = 'val_xxsmall' # very fast to convert @@ -109,7 +111,7 @@ def create_c4_dataset_xxsmall(path: pathlib.Path) -> str: return c4_dir -def create_arxiv_dataset(path: pathlib.Path) -> str: +def create_arxiv_dataset(path: Path) -> str: """Creates an arxiv dataset.""" arxiv_dir = os.path.join(path, f'my-copy-arxiv') downloaded_split = 'train' From d991f37dda4b46af95728dfc7d1b786f8ddd8017 Mon Sep 17 00:00:00 2001 From: Shashank Rajput <144760128+ShashankMosaicML@users.noreply.github.com> Date: Fri, 5 Jan 2024 13:40:38 -0800 Subject: [PATCH 02/31] Adding support for alibi when using flash attention (#820) * .. * .. * .. * .. * .. * .. * .. * .. * .. * .. * .. * .. * .. * .. * .. * .. * .. * .. * .. * .. * .. * Update llmfoundry/models/layers/attention.py Co-authored-by: Aaron Gokaslan * .. * .. * .. * .. * .. * Update tests/models/layers/test_flash_attn.py Co-authored-by: Irene Dea * .. * .. * Update tests/models/layers/test_flash_attn.py Co-authored-by: Irene Dea * .. --------- Co-authored-by: Shashank Rajput Co-authored-by: Aaron Gokaslan Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com> Co-authored-by: Irene Dea --- llmfoundry/models/layers/attention.py | 25 +++- llmfoundry/models/layers/blocks.py | 2 + llmfoundry/models/mpt/configuration_mpt.py | 11 +- llmfoundry/models/mpt/modeling_mpt.py | 11 +- setup.py | 2 +- tests/models/layers/test_flash_attn.py | 113 +++++++++++++++++- .../models/layers/test_flash_triton_torch.py | 29 +++-- tests/models/test_model.py | 35 +++--- 8 files changed, 192 insertions(+), 36 deletions(-) diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py index 2a90bf2f80..0fb6c0a042 100644 --- a/llmfoundry/models/layers/attention.py +++ b/llmfoundry/models/layers/attention.py @@ -39,6 +39,11 @@ def is_transformers_version_gte(hf_version: str) -> bool: return version.parse(transformers.__version__) >= version.parse(hf_version) +def check_alibi_support(attention_impl: str) -> bool: + return attention_impl != 'flash' or is_flash_v2_installed( + v2_version='v2.4.2') + + # Before importing any transformers models, we need to disable transformers flash attention if # we are in an environment with flash attention version <2. Transformers hard errors on a not properly # gated import otherwise. @@ -226,6 +231,7 @@ def flash_attn_fn( attention_mask_in_length: Optional[torch.Tensor] = None, should_repeat_kv_for_gqa: Optional[bool] = True, sliding_window_size: int = -1, + alibi_slopes: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor, torch.Tensor]]]: try: @@ -334,6 +340,12 @@ def flash_attn_fn( causal=reset_is_causal, return_attn_probs=needs_weights) elif is_flash_v2_installed(): + alibi_kwargs = {} + if check_alibi_support('flash'): + alibi_kwargs = {'alibi_slopes': alibi_slopes} + elif alibi_slopes is not None: + raise ValueError( + 'alibi_slopes is only supported for flash-attn>=2.4.2') output_unpad = flash_attn_interface.flash_attn_varlen_func( q=query_unpad, k=key_unpad, @@ -346,10 +358,11 @@ def flash_attn_fn( softmax_scale=softmax_scale, causal=reset_is_causal, return_attn_probs=needs_weights, - window_size=(sliding_window_size, sliding_window_size)) + window_size=(sliding_window_size, sliding_window_size), + **alibi_kwargs) else: raise RuntimeError( - 'flash-attn==1.0.9 or flash-attn==2.3.6 is required.') + 'flash-attn==1.0.9 or flash-attn==2.4.2 is required.') output = bert_padding.pad_input( rearrange(output_unpad, 'nnz h d -> nnz (h d)'), indices_q, batch_size, @@ -587,6 +600,7 @@ def forward( is_causal: bool = True, needs_weights: bool = False, attention_mask_in_length: Optional[torch.Tensor] = None, + alibi_slopes: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[ torch.Tensor, torch.Tensor]]]: qkv = self.Wqkv(x) @@ -656,6 +670,7 @@ def forward( 'attention_mask_in_length': attention_mask_in_length, 'should_repeat_kv_for_gqa': not is_flash_v2_installed(), 'sliding_window_size': self.sliding_window_size, + 'alibi_slopes': alibi_slopes, } context, attn_weights, past_key_value = self.attn_fn( @@ -805,7 +820,8 @@ def build_attn_bias( def gen_slopes(n_heads: int, alibi_bias_max: int = 8, - device: Optional[torch.device] = None) -> torch.Tensor: + device: Optional[torch.device] = None, + return_1d: bool = False) -> torch.Tensor: _n_heads = 2**math.ceil(math.log2(n_heads)) m = torch.arange(1, _n_heads + 1, dtype=torch.float32, device=device) m = m.mul(alibi_bias_max / _n_heads) @@ -816,7 +832,8 @@ def gen_slopes(n_heads: int, # Huggingface and FasterTransformer calculate slopes normally, # then return this strided concatenation of slopes slopes = torch.concat([slopes[1::2], slopes[::2]])[:n_heads] - + if return_1d: + return slopes return slopes.view(1, n_heads, 1, 1) diff --git a/llmfoundry/models/layers/blocks.py b/llmfoundry/models/layers/blocks.py index c077ccb535..e5032998dc 100644 --- a/llmfoundry/models/layers/blocks.py +++ b/llmfoundry/models/layers/blocks.py @@ -123,6 +123,7 @@ def forward( is_causal: bool = True, output_attentions: bool = False, attention_mask_in_length: Optional[torch.Tensor] = None, + alibi_slopes: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[ torch.Tensor, torch.Tensor]]]: a = self.norm_1(x) @@ -135,6 +136,7 @@ def forward( is_causal=is_causal, needs_weights=output_attentions, attention_mask_in_length=attention_mask_in_length, + alibi_slopes=alibi_slopes, ) x = x + self.resid_attn_dropout(b) m = x diff --git a/llmfoundry/models/mpt/configuration_mpt.py b/llmfoundry/models/mpt/configuration_mpt.py index ae4754108c..5474529277 100644 --- a/llmfoundry/models/mpt/configuration_mpt.py +++ b/llmfoundry/models/mpt/configuration_mpt.py @@ -8,7 +8,8 @@ from transformers import PretrainedConfig -from llmfoundry.models.layers.attention import is_flash_v2_installed +from llmfoundry.models.layers.attention import (check_alibi_support, + is_flash_v2_installed) from llmfoundry.models.layers.blocks import attn_config_defaults # NOTE: All utils are imported directly even if unused so that @@ -220,11 +221,11 @@ def _validate_config(self) -> None: 'attn_impl'] not in ['torch', 'triton']: raise NotImplementedError( 'prefix_lm only implemented with torch and triton attention.') - if self.attn_config['alibi'] and self.attn_config['attn_impl'] not in [ - 'torch', 'triton' - ]: + if self.attn_config['alibi'] and not check_alibi_support( + self.attn_config['attn_impl']): raise NotImplementedError( - 'alibi only implemented with torch and triton attention.') + 'alibi only implemented with torch, triton, and flash (v2.4.2 or higher) attention.' + ) if self.attn_config['attn_uses_sequence_id'] and not ( self.attn_config['attn_impl'] in ['torch', 'triton'] or (self.attn_config['attn_impl'] == 'flash' and diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 4c80b10ed9..e2274ffd6c 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -47,7 +47,7 @@ from llmfoundry.models.layers.attention import (ATTN_CLASS_REGISTRY, attn_bias_shape, - build_attn_bias) + build_attn_bias, gen_slopes) from llmfoundry.models.layers.blocks import MPTBlock from llmfoundry.models.layers.custom_embedding import SharedEmbedding from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY as FC_CLASS_REGISTRY @@ -607,6 +607,14 @@ def forward( attn_uses_sequence_id=self.attn_uses_sequence_id, attn_impl=self.attn_impl, attention_mask=attention_mask) + + alibi_slopes = None # alibi_slopes will only be used by flash attention for ALiBi + if self.alibi and self.attn_impl == 'flash': + alibi_slopes = gen_slopes(n_heads=self.config.n_heads, + alibi_bias_max=self.alibi_bias_max, + device=x.device, + return_1d=True) + # initialize the past key values cache if it should be used presents = () if use_cache else None if use_cache and past_key_values is None: @@ -630,6 +638,7 @@ def forward( is_causal=self.is_causal, output_attentions=bool(output_attentions), attention_mask_in_length=attention_mask_in_length, + alibi_slopes=alibi_slopes, ) if presents is not None: presents += (present,) diff --git a/setup.py b/setup.py index 8122bbb14f..152c682a2d 100644 --- a/setup.py +++ b/setup.py @@ -98,7 +98,7 @@ 'xentropy-cuda-lib@git+https://github.com/HazyResearch/flash-attention.git@v1.0.9#subdirectory=csrc/xentropy', ] extra_deps['gpu-flash2'] = [ - 'flash-attn==2.3.6', + 'flash-attn==2.4.2', 'mosaicml-turbo==0.0.7', ] diff --git a/tests/models/layers/test_flash_attn.py b/tests/models/layers/test_flash_attn.py index acefd2c42d..3e1ec37b2e 100644 --- a/tests/models/layers/test_flash_attn.py +++ b/tests/models/layers/test_flash_attn.py @@ -6,7 +6,10 @@ import pytest import torch -from llmfoundry.models.layers.attention import (flash_attn_fn, +from llmfoundry.models.layers.attention import (attn_bias_shape, + build_attn_bias, + check_alibi_support, + flash_attn_fn, gen_slopes, is_flash_v2_installed, triton_flash_attn_fn) @@ -253,3 +256,111 @@ def test_sliding_window(sliding_window_size: int): ) <= 1e-2 + 1e-2 * torch.norm(key_2.grad) assert torch.norm(value_2.grad - value_1.grad # type: ignore ) <= 1e-2 + 1e-2 * torch.norm(value_2.grad) + + +@pytest.mark.gpu +@pytest.mark.skipif( + not check_alibi_support('flash'), + reason='ALiBi only supported by Flash Attention after v2.4.2.') +@pytest.mark.parametrize('n_heads', [1, 6, 8]) +def test_alibi_bias(n_heads: int): + # Test that sliding window attention works as expected. + dtype = torch.bfloat16 + device = 'cuda' + d = 128 + seqlen_1 = 8 + bsz = 2 + + query_1 = torch.randn(bsz, seqlen_1, n_heads * d).to(dtype=dtype, + device=device) + query_1.requires_grad = True + key_1 = torch.randn(bsz, seqlen_1, n_heads * d).to(dtype=dtype, + device=device) + key_1.requires_grad = True + value_1 = torch.randn(bsz, seqlen_1, n_heads * d).to(dtype=dtype, + device=device) + value_1.requires_grad = True + alibi_slopes_1 = gen_slopes(n_heads=n_heads, + alibi_bias_max=8, + device=torch.device(device), + return_1d=True) + output_1, _, _ = flash_attn_fn(query=query_1, + key=key_1, + value=value_1, + n_heads=n_heads, + kv_n_heads=n_heads, + past_key_value=None, + softmax_scale=1 / math.sqrt(d), + attn_bias=None, + key_padding_mask=None, + is_causal=True, + dropout_p=0.0, + training=False, + needs_weights=False, + multiquery=False, + attention_mask_in_length=None, + should_repeat_kv_for_gqa=True, + alibi_slopes=alibi_slopes_1) + + output_1.sum().backward() + + query_2 = query_1.detach().clone() + query_2.requires_grad = True + key_2 = key_1.detach().clone() + key_2.requires_grad = True + value_2 = value_1.detach().clone() + value_2.requires_grad = True + + def gen_bias(): + causal = True + bs = attn_bias_shape('triton', + n_heads, + seqlen_1, + True, + prefix_lm=False, + use_sequence_id=False, + causal=causal) + + attn_bias = torch.zeros(*bs, device=device) + attn_bias = build_attn_bias( + 'triton', + attn_bias, + n_heads, + seqlen_1, + causal=causal, + alibi=True, + alibi_bias_max=8, + ) + return attn_bias + + attn_bias_2 = gen_bias() + + output_2, _, _ = triton_flash_attn_fn( + query=query_2, + key=key_2, + value=value_2, + n_heads=n_heads, + kv_n_heads=n_heads, + past_key_value=None, + softmax_scale=1 / math.sqrt(d), + attn_bias=attn_bias_2, + key_padding_mask=None, + is_causal=True, + dropout_p=0.0, + training=False, + needs_weights=False, + multiquery=False, + ) + + output_2.sum().backward() + + assert torch.allclose(output_1, output_2) + assert (query_2.grad is not None) and (query_1.grad is not None) + assert torch.norm(query_2.grad - + query_1.grad) <= 1e-2 + 1e-2 * torch.norm(query_2.grad) + assert (key_2.grad is not None) and (key_1.grad is not None) + assert torch.norm(key_2.grad - + key_1.grad) <= 1e-2 + 1e-2 * torch.norm(key_2.grad) + assert (value_2.grad is not None) and (value_1.grad is not None) + assert torch.norm(value_2.grad - + value_1.grad) <= 1e-2 + 1e-2 * torch.norm(value_2.grad) diff --git a/tests/models/layers/test_flash_triton_torch.py b/tests/models/layers/test_flash_triton_torch.py index 454fda311d..4ca5c7b668 100644 --- a/tests/models/layers/test_flash_triton_torch.py +++ b/tests/models/layers/test_flash_triton_torch.py @@ -6,7 +6,8 @@ from omegaconf import OmegaConf as om from llmfoundry.models.layers import attention -from llmfoundry.models.layers.attention import is_flash_v2_installed +from llmfoundry.models.layers.attention import (check_alibi_support, gen_slopes, + is_flash_v2_installed) from llmfoundry.models.mpt.modeling_mpt import (apply_sequence_id, gen_attention_mask_in_length, gen_rotary_embedding) @@ -20,7 +21,7 @@ def allclose_helper(t0: torch.Tensor, @pytest.mark.gpu -@pytest.mark.parametrize('attn_impl_0,attn_impl_1', [ +@pytest.mark.parametrize('attn_impl_0, attn_impl_1', [ ('flash', 'triton'), ('flash', 'torch'), ('triton', 'torch'), @@ -74,9 +75,9 @@ def test_attn_impl(attn_impl_0: str, """ alibi = pos_emb_config['alibi'] rope = pos_emb_config['rope'] - if alibi and (attn_impl_0 == 'flash' or attn_impl_1 == 'flash'): - pytest.skip('flash attn does not support alibi') - + if alibi and not (check_alibi_support(attn_impl_0) and + check_alibi_support(attn_impl_1)): + pytest.skip('flash attention below v2.4.2 does not support alibi.') if rope and (pos_emb_config['rope_impl'] == 'dail') and (not is_flash_v2_installed()): pytest.skip('dail implementation of rope requires flash attention 2.') @@ -177,6 +178,12 @@ def gen_bias(attn_impl: str): with torch.autocast(x0.device.type): attn_bias_0 = gen_bias(attn_impl_0) + alibi_slopes_0 = None + if alibi and attn_impl_0 == 'flash': + alibi_slopes_0 = gen_slopes(n_heads=cfg.n_heads, + alibi_bias_max=8, + device=torch.device(device), + return_1d=True) rotary_emb_w_meta_info = None if rope: rotary_embedding = gen_rotary_embedding( @@ -209,15 +216,23 @@ def gen_bias(attn_impl: str): attention_mask=attention_mask, rotary_emb_w_meta_info=rotary_emb_w_meta_info, is_causal=True, - attention_mask_in_length=attention_mask_in_length_0) + attention_mask_in_length=attention_mask_in_length_0, + alibi_slopes=alibi_slopes_0) attn_bias_1 = gen_bias(attn_impl_1) + alibi_slopes_1 = None + if alibi and attn_impl_1 == 'flash': + alibi_slopes_1 = gen_slopes(n_heads=cfg.n_heads, + alibi_bias_max=8, + device=torch.device(device), + return_1d=True) y1, _, _ = attn1(x1, past_key_value=None, attn_bias=attn_bias_1, attention_mask=attention_mask, rotary_emb_w_meta_info=rotary_emb_w_meta_info, is_causal=True, - attention_mask_in_length=attention_mask_in_length_1) + attention_mask_in_length=attention_mask_in_length_1, + alibi_slopes=alibi_slopes_1) y0 *= attention_mask.unsqueeze(-1) y1 *= attention_mask.unsqueeze(-1) diff --git a/tests/models/test_model.py b/tests/models/test_model.py index 7bccad089d..64a92f6cc6 100644 --- a/tests/models/test_model.py +++ b/tests/models/test_model.py @@ -27,7 +27,8 @@ from llmfoundry import COMPOSER_MODEL_REGISTRY, ComposerHFCausalLM from llmfoundry.models.hf.model_wrapper import HuggingFaceModelWithZLoss from llmfoundry.models.layers import NORM_CLASS_REGISTRY, build_alibi_bias -from llmfoundry.models.layers.attention import is_flash_v2_installed +from llmfoundry.models.layers.attention import (check_alibi_support, + is_flash_v2_installed) from llmfoundry.models.layers.blocks import MPTBlock from llmfoundry.models.mpt import MPTConfig, MPTForCausalLM from llmfoundry.utils import build_tokenizer @@ -647,8 +648,8 @@ def test_mpt_creation(norm_type: str, no_bias: bool, tie_word_embeddings: bool, def test_sequence_id_based_masking(attention_impl: str, pos_emb_config: dict): # Testing the output of concatenated sequence with sequence id masking vs individual sequences. alibi = pos_emb_config['alibi'] - if alibi and attention_impl == 'flash': - pytest.skip(f'alibi only implemented with torch and triton attention.') + if alibi and not check_alibi_support(attention_impl): + pytest.skip(f'flash attention below v2.4.2 does not support alibi.') rope = pos_emb_config['rope'] if rope and pos_emb_config[ @@ -766,8 +767,8 @@ def test_forward_with_padding(attention_impl: str, pos_emb_config: dict, tie_word_embeddings: bool): # Test that different placement of padding does not affect the output. alibi = pos_emb_config['alibi'] - if alibi and attention_impl == 'flash': - pytest.skip(f'alibi only implemented with torch and triton attention.') + if alibi and not check_alibi_support(attention_impl): + pytest.skip(f'flash attention below v2.4.2 does not support alibi.') rope = pos_emb_config['rope'] if rope and pos_emb_config[ @@ -1028,8 +1029,8 @@ def test_generate(attention_impl: str, precision: str, pos_emb_config: dict, tie_word_embeddings: bool): # Test that generate works, and produces the same output with or without # padding in the input. - if pos_emb_config['alibi'] and attention_impl == 'flash': - pytest.skip(f'alibi only implemented with torch and triton attention.') + if pos_emb_config['alibi'] and not check_alibi_support(attention_impl): + pytest.skip(f'flash attention below v2.4.2 does not support alibi.') if pos_emb_config['rope'] and pos_emb_config[ 'rope_impl'] == 'dail' and not is_flash_v2_installed(): @@ -1277,8 +1278,8 @@ def test_save_from_pretrained(tmp_path: pathlib.Path): }]) def test_forward_with_cache_and_padding(attn_impl: str, pos_emb_config: dict): # Tests that the result is the same with or without padding when using kv caching - if pos_emb_config['alibi'] and attn_impl == 'flash': - pytest.skip(f'alibi only implemented with torch and triton attention.') + if pos_emb_config['alibi'] and not check_alibi_support(attn_impl): + pytest.skip(f'flash attention below v2.4.2 does not support alibi.') if pos_emb_config['rope'] and pos_emb_config[ 'rope_impl'] == 'dail' and not is_flash_v2_installed(): pytest.skip( @@ -1414,8 +1415,8 @@ def test_forward_with_cache(attn_impl: str, pos_emb_config: dict, tie_word_embeddings: bool): # Test that model forward with and without the key-value cache produces the # same output. - if pos_emb_config['alibi'] and attn_impl == 'flash': - pytest.skip(f'alibi only implemented with torch and triton attention.') + if pos_emb_config['alibi'] and not check_alibi_support(attn_impl): + pytest.skip(f'flash attention below v2.4.2 does not support alibi.') if pos_emb_config['rope'] and pos_emb_config[ 'rope_impl'] == 'dail' and not is_flash_v2_installed(): @@ -1551,8 +1552,8 @@ def test_forward_with_cache(attn_impl: str, pos_emb_config: dict, @pytest.mark.parametrize('tie_word_embeddings', [True, False]) def test_generate_with_past_kv(attn_impl: str, pos_emb_config: dict, tie_word_embeddings: bool): - if pos_emb_config['alibi'] and attn_impl == 'flash': - pytest.skip(f'alibi only implemented with torch and triton attention.') + if pos_emb_config['alibi'] and not check_alibi_support(attn_impl): + pytest.skip(f'flash attention below v2.4.2 does not support alibi.') if pos_emb_config['rope'] and pos_emb_config[ 'rope_impl'] == 'dail' and not is_flash_v2_installed(): pytest.skip( @@ -1658,8 +1659,8 @@ def test_generation_kwargs_dont_crash(attn_impl: str, generation_kwargs: Dict[str, Any], pos_emb_config: dict, tie_word_embeddings: bool): - if pos_emb_config['alibi'] and attn_impl == 'flash': - pytest.skip(f'alibi only implemented with torch and triton attention.') + if pos_emb_config['alibi'] and not check_alibi_support(attn_impl): + pytest.skip(f'flash attention below v2.4.2 does not support alibi.') if pos_emb_config['rope'] and pos_emb_config[ 'rope_impl'] == 'dail' and not is_flash_v2_installed(): @@ -1847,8 +1848,8 @@ def test_alibi_vs_hf(): }]) def test_forward_with_output_attentions_and_output_hidden_states( attn_impl: str, pos_emb_config: dict): - if pos_emb_config['alibi'] and attn_impl == 'flash': - pytest.skip(f'alibi only implemented with torch and triton attention.') + if pos_emb_config['alibi'] and not check_alibi_support(attn_impl): + pytest.skip(f'flash attention below v2.4.2 does not support alibi.') if attn_impl in ['flash', 'triton']: pytest.skip(f'output_attentions only implemented with torch attention.') if pos_emb_config['rope'] and pos_emb_config[ From 5e85bd62d4ec9e13353ca95877477602ab538e5b Mon Sep 17 00:00:00 2001 From: Shashank Rajput <144760128+ShashankMosaicML@users.noreply.github.com> Date: Sat, 6 Jan 2024 12:25:17 -0800 Subject: [PATCH 03/31] Shashank/new benchmarks (#838) * adding new benchmarks * Update submit_benchmarks.py --------- Co-authored-by: Shashank Rajput --- scripts/train/benchmarking/README.md | 23 +++++ .../train/benchmarking/submit_benchmarks.py | 13 +-- scripts/train/benchmarking/sweep.py | 91 +++++++++++++++++++ 3 files changed, 121 insertions(+), 6 deletions(-) create mode 100644 scripts/train/benchmarking/sweep.py diff --git a/scripts/train/benchmarking/README.md b/scripts/train/benchmarking/README.md index c3c8bc1c74..5414cdc7bf 100644 --- a/scripts/train/benchmarking/README.md +++ b/scripts/train/benchmarking/README.md @@ -69,6 +69,29 @@ Our microbatching engine enables microbatch sizes that do not divde Global Batch [comment]: # (TODO: Update tables with torch 2.0 after next Composer release) +## H100 80GB BF16 (Large Scale, >= 128 GPUs) +| Model | SeqLen (T) | # GPUs | GPU | MFU | HFU | Model TFLOP | MicroBatchSize | GradAccum | GlobalBatchSize | Throughput (S/s) | Throughput (T/s) | Throughput (T/s/GPU) | GlobalBatchSize (T) | Precision | MP Mode | Sharding Strategy | Activation Checkpointing | Activation CPUOffload | NumParams | +| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | +| 70b | 2048 | 512 | h100_80gb | 41.25 | 55.0 | 408 | 8 | 1 | 4096 | 251 | 515636 | 1007 | 8388608 | amp_bf16 | DEFAULT | FULL_SHARD | True | False | 64862437376 | +| 70b | 2048 | 256 | h100_80gb | 42.42 | 56.56 | 419 | 8 | 1 | 2048 | 129 | 265149 | 1035 | 4194304 | amp_bf16 | DEFAULT | FULL_SHARD | True | False | 64862437376 | +| 70b | 2048 | 128 | h100_80gb | 43.36 | 57.81 | 428 | 8 | 1 | 1024 | 66 | 135490 | 1058 | 2097152 | amp_bf16 | DEFAULT | FULL_SHARD | True | False | 64862437376 | +| 30b | 2048 | 512 | h100_80gb | 40.27 | 53.69 | 398 | 8 | 1 | 4096 | 528 | 1083366 | 2115 | 8388608 | amp_bf16 | DEFAULT | FULL_SHARD | True | False | 29975214080 | +| 30b | 2048 | 256 | h100_80gb | 40.89 | 54.52 | 404 | 8 | 1 | 2048 | 268 | 550022 | 2148 | 4194304 | amp_bf16 | DEFAULT | FULL_SHARD | True | False | 29975214080 | +| 30b | 2048 | 128 | h100_80gb | 41.85 | 55.8 | 414 | 8 | 1 | 1024 | 137 | 281491 | 2199 | 2097152 | amp_bf16 | DEFAULT | FULL_SHARD | True | False | 29975214080 | +| 13b | 2048 | 512 | h100_80gb | 41.12 | 54.83 | 406 | 16 | 1 | 8192 | 1238 | 2535811 | 4952 | 16777216 | amp_bf16 | DEFAULT | FULL_SHARD | True | False | 12853954560 | +| 13b | 2048 | 256 | h100_80gb | 41.42 | 55.23 | 409 | 16 | 1 | 4096 | 623 | 1277214 | 4989 | 8388608 | amp_bf16 | DEFAULT | FULL_SHARD | True | False | 12853954560 | +| 13b | 2048 | 128 | h100_80gb | 42.18 | 56.24 | 417 | 16 | 1 | 2048 | 317 | 650264 | 5080 | 4194304 | amp_bf16 | DEFAULT | FULL_SHARD | True | False | 12853954560 | +| 7b | 2048 | 512 | h100_80gb | 42.2 | 42.2 | 417 | 6 | 1 | 3072 | 2417 | 4951479 | 9670 | 6291456 | amp_bf16 | DEFAULT | FULL_SHARD | False | False | 6658859008 | +| 7b | 2048 | 256 | h100_80gb | 44.15 | 44.15 | 436 | 6 | 1 | 1536 | 1264 | 2590548 | 10119 | 3145728 | amp_bf16 | DEFAULT | FULL_SHARD | False | False | 6658859008 | +| 7b | 2048 | 128 | h100_80gb | 45.71 | 45.71 | 452 | 6 | 1 | 768 | 654 | 1340830 | 10475 | 1572864 | amp_bf16 | DEFAULT | FULL_SHARD | False | False | 6658859008 | +| 3b | 2048 | 512 | h100_80gb | 39.24 | 39.24 | 388 | 8 | 1 | 4096 | 5416 | 11092218 | 21664 | 8388608 | amp_bf16 | DEFAULT | SHARD_GRAD_OP | False | False | 2651837440 | +| 3b | 2048 | 256 | h100_80gb | 41.25 | 41.25 | 408 | 8 | 1 | 2048 | 2846 | 5829686 | 22772 | 4194304 | amp_bf16 | DEFAULT | SHARD_GRAD_OP | False | False | 2651837440 | +| 3b | 2048 | 128 | h100_80gb | 42.43 | 42.43 | 419 | 8 | 1 | 1024 | 1463 | 2998098 | 23422 | 2097152 | amp_bf16 | DEFAULT | SHARD_GRAD_OP | False | False | 2651837440 | +| 1b | 2048 | 512 | h100_80gb | 36.65 | 36.65 | 362 | 12 | 1 | 6144 | 9959 | 20396905 | 39837 | 12582912 | amp_bf16 | DEFAULT | SHARD_GRAD_OP | False | False | 1315950592 | +| 1b | 2048 | 256 | h100_80gb | 39.15 | 39.15 | 387 | 12 | 1 | 3072 | 5319 | 10894207 | 42555 | 6291456 | amp_bf16 | DEFAULT | SHARD_GRAD_OP | False | False | 1315950592 | +| 1b | 2048 | 128 | h100_80gb | 40.6 | 40.6 | 401 | 12 | 1 | 1536 | 2757 | 5647854 | 44123 | 3145728 | amp_bf16 | DEFAULT | SHARD_GRAD_OP | False | False | 1315950592 | + + ## H100 80GB BF16 | Model | SeqLen (T) | # GPUs | GPU | MFU | HFU | Model TFLOP | MicroBatchSize | GradAccum | GlobalBatchSize | Throughput (S/s) | Throughput (T/s) | Throughput (T/s/GPU) | GlobalBatchSize (T) | Precision | MP Mode | Sharding Strategy | Activation Checkpointing | Activation CPUOffload | NumParams | | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | diff --git a/scripts/train/benchmarking/submit_benchmarks.py b/scripts/train/benchmarking/submit_benchmarks.py index bfff10165a..a020745581 100644 --- a/scripts/train/benchmarking/submit_benchmarks.py +++ b/scripts/train/benchmarking/submit_benchmarks.py @@ -376,7 +376,7 @@ def get_integrations(project: str, git_integration.update({ 'integration_type': 'git_repo', 'git_repo': 'mosaicml/llm-foundry', - 'pip_install': '-e .[gpu]' + 'pip_install': '.[gpu-flash2]' }) integrations = [git_integration] @@ -398,8 +398,8 @@ def run_config(config: Tuple[str, int, int, str, str, int, str], { 'integration_type': 'git_repo', 'git_repo': 'mosaicml/llm-foundry', - 'git_branch': 'v0.4.0', - 'pip_install': '-e .[gpu]', + 'git_branch': 'main', + 'pip_install': '.[gpu-flash2]', }, { 'integration_type': 'wandb', @@ -411,7 +411,7 @@ def run_config(config: Tuple[str, int, int, str, str, int, str], command = '' if gpu_type == 'h100_80gb' and 'fp8' in precision: # Required for flash-attn and FP8 training command += f""" - pip install flash-attn==1.0.7 --no-build-isolation + pip install flash-attn==2.4.2 --no-build-isolation pip install git+https://github.com/NVIDIA/TransformerEngine.git@v0.10 pip uninstall install pydantic --yes pip install pydantic==1.9.0 @@ -420,11 +420,11 @@ def run_config(config: Tuple[str, int, int, str, str, int, str], if args.data_remote is None: command += f""" cd llm-foundry/scripts - python data_prep/convert_dataset_hf.py --dataset c4 --data_subset en --out_root ./my-copy-c4 --splits train_small val_small --concat_tokens {max_seq_len} --tokenizer gpt2 --eos_text '<|endoftext|>' + python data_prep/convert_dataset_hf.py --dataset c4 --data_subset en --out_root ./my-copy-c4 --splits train_small val_small --concat_tokens {max_seq_len} --eos_text '<|endoftext|>' composer train/train.py /mnt/config/parameters.yaml """ else: - command = f""" + command += f""" cd llm-foundry/scripts composer train/train.py /mnt/config/parameters.yaml """ @@ -487,6 +487,7 @@ def run_config(config: Tuple[str, int, int, str, str, int, str], print(f'Launching run {run.name}') else: print(f'run = {name}') + print(f'{config=}') def run_check_capacity(model_yaml: str, diff --git a/scripts/train/benchmarking/sweep.py b/scripts/train/benchmarking/sweep.py new file mode 100644 index 0000000000..441c9825f8 --- /dev/null +++ b/scripts/train/benchmarking/sweep.py @@ -0,0 +1,91 @@ +# Copyright 2022 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +import os + +# Define the arguments to sweep over + +base_args = [ + '--project tput', + '--image ', + '--git_branch main', + '--precisions bf16', + '--fsdp_config_mixed_precision PURE', + '--fsdp_config_limit_all_gathers true', + '--fsdp_config_forward_prefetch true', + '--fsdp_config_backward_prefetch BACKWARD_PRE', + '--activation_cpu_offload false', + '--seq_len_exp 11 11', + '--accum 1', + '--clusters ', + '--gpu_types h100_80gb', + '--data_remote ', + '--wandb true', + '--priority lowest', + '--RUN true', +] + +num_gpu_args_list = [ + [ + '--gpu_nums 128', + ], + [ + '--gpu_nums 256', + ], + [ + '--gpu_nums 512', + ], +] + +model_args_list = [ + [ + '--model_yamls 1b.yaml', + '--fsdp_config_activation_checkpointing false', + '--fsdp_config_shard_strategy SHARD_GRAD_OP', + '--microbatch_size 12', + '--attn_impl flash', + ], + [ + '--model_yamls 3b.yaml', + '--fsdp_config_activation_checkpointing false', + '--fsdp_config_shard_strategy SHARD_GRAD_OP', + '--microbatch_size 8', + '--attn_impl flash', + ], + [ + '--model_yamls 7b.yaml', + '--fsdp_config_activation_checkpointing false', + '--fsdp_config_shard_strategy FULL_SHARD', + '--microbatch_size 6', + '--attn_impl flash', + ], + [ + '--model_yamls 13b.yaml', + '--fsdp_config_activation_checkpointing true', + '--fsdp_config_shard_strategy FULL_SHARD', + '--microbatch_size 16', + '--attn_impl triton', + ], + [ + '--model_yamls 30b.yaml', + '--fsdp_config_activation_checkpointing true', + '--fsdp_config_shard_strategy FULL_SHARD', + '--microbatch_size 8', + '--attn_impl triton', + ], + [ + '--model_yamls 70b.yaml', + '--fsdp_config_activation_checkpointing true', + '--fsdp_config_shard_strategy FULL_SHARD', + '--microbatch_size 8', + '--attn_impl flash', + ], +] + +# Iterate over the arguments and call submit_benchmarks.py +for num_gpu_args in num_gpu_args_list: + for model_args in model_args_list: + command = ['python submit_benchmarks.py' + ] + base_args + num_gpu_args + model_args + command = ' '.join(command) + os.system(command) From 5b994884f3200abf94ab6e35d35098daa3f420f4 Mon Sep 17 00:00:00 2001 From: Daniel King <43149077+dakinggg@users.noreply.github.com> Date: Mon, 8 Jan 2024 11:37:39 -0800 Subject: [PATCH 04/31] Fix error when decoding a token in the id gap (or out of range) in a tiktoken tokenizer (#841) --- llmfoundry/tokenizers/tiktoken.py | 5 ++++- llmfoundry/utils/huggingface_hub_utils.py | 4 ++-- tests/tokenizers/test_tiktoken.py | 20 ++++++++++++++++++-- 3 files changed, 24 insertions(+), 5 deletions(-) diff --git a/llmfoundry/tokenizers/tiktoken.py b/llmfoundry/tokenizers/tiktoken.py index 342b5c2ecf..2632985533 100644 --- a/llmfoundry/tokenizers/tiktoken.py +++ b/llmfoundry/tokenizers/tiktoken.py @@ -253,7 +253,10 @@ def _convert_token_to_id(self, token: str) -> Optional[int]: def _convert_id_to_token(self, index: int) -> Optional[str]: """Converts an index (integer) in a token (str) using the vocab.""" - return self.decoder.get(index) + # For tokens in either the gap in ids in the tokenizer, or beyond the range of the tokenizer, + # we return empty string. This matches the behavior of Hugging Face fast tokenizers, + # but not slow tokenizers. + return self.decoder.get(index, '') def convert_tokens_to_string(self, tokens: List[str]) -> str: """Converts a sequence of tokens (string) in a single string.""" diff --git a/llmfoundry/utils/huggingface_hub_utils.py b/llmfoundry/utils/huggingface_hub_utils.py index a74ab1cc35..07a9c3900e 100644 --- a/llmfoundry/utils/huggingface_hub_utils.py +++ b/llmfoundry/utils/huggingface_hub_utils.py @@ -59,7 +59,7 @@ def process_file( folder_path: str, flatten_imports_prefix: Sequence[str], ) -> list[str]: - with open(file_path, 'r') as f: + with open(file_path, 'r', encoding='utf-8') as f: source = f.read() parent_module_name = None @@ -102,7 +102,7 @@ def process_file( if new_filename == '__init__.py': new_filename = file_path.split('/')[-2] + '.py' new_file_path = os.path.join(folder_path, new_filename) - with open(new_file_path, 'w') as f: + with open(new_file_path, 'w', encoding='utf-8') as f: assert new_tree is not None f.write(ast.unparse(new_tree)) diff --git a/tests/tokenizers/test_tiktoken.py b/tests/tokenizers/test_tiktoken.py index 1ade2ea156..6a4d1c99c4 100644 --- a/tests/tokenizers/test_tiktoken.py +++ b/tests/tokenizers/test_tiktoken.py @@ -338,6 +338,7 @@ def test_additional_special_tokens(model_name: Optional[str], encoding_name: Optional[str], tmp_path: pathlib.Path): special_token_to_add = '<|im_start|>' + input_string = special_token_to_add + ' hello' wrapped_tokenizer, _, _ = get_tokenizers_for_testing( model_name, encoding_name, @@ -345,12 +346,15 @@ def test_additional_special_tokens(model_name: Optional[str], add_bos_token=False, add_eos_token=False, additional_special_tokens=[special_token_to_add]) - encoded_outputs = wrapped_tokenizer(special_token_to_add + - ' hello')['input_ids'] + encoded_outputs = wrapped_tokenizer(input_string)['input_ids'] assert encoded_outputs[0] == wrapped_tokenizer.vocab_size assert len(encoded_outputs) == 2 + decoded_outputs = wrapped_tokenizer.decode( + encoded_outputs, spaces_between_special_tokens=False) + assert decoded_outputs == input_string + @pytest.mark.parametrize('model_name,encoding_name', MODEL_ENCODING_NAME_PARAMETRIZATION) @@ -386,3 +390,15 @@ def test_chat_formatting(model_name: Optional[str], chat_str = wrapped_tokenizer.apply_chat_template( dict_chats, tokenize=False, add_generation_prompt=True) assert chat_str == MULTI_TURN_GENERATE_STRING[i] + + +def test_tiktoken_out_of_range(): + wrapped_tokenizer = TiktokenTokenizerWrapper(model_name='gpt-4',) + + # For gpt-4, 100256 is less than the vocab size, but is not a valid token + assert wrapped_tokenizer.decode([100256]) == '' + assert wrapped_tokenizer.decode(100256) == '' + + # For gpt-4, 1000000 is greater than the vocab size + assert wrapped_tokenizer.decode([1000000]) == '' + assert wrapped_tokenizer.decode(1000000) == '' From c03ca1a58903e426dbbaa9f70a6cf65f38ca3013 Mon Sep 17 00:00:00 2001 From: Irene Dea Date: Mon, 8 Jan 2024 13:46:55 -0800 Subject: [PATCH 05/31] Add use_tokenizer_eos option to convert text to mds script (#843) * Add use_tokenizer_eos option to convert text to mds script * Do store_true action for use_tokenizer_eos --- scripts/data_prep/convert_text_to_mds.py | 24 +++++++++++++++++++----- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/scripts/data_prep/convert_text_to_mds.py b/scripts/data_prep/convert_text_to_mds.py index dc7c514d75..2218e575b2 100644 --- a/scripts/data_prep/convert_text_to_mds.py +++ b/scripts/data_prep/convert_text_to_mds.py @@ -47,18 +47,21 @@ def parse_args() -> Namespace: '--compression', type=str, default='zstd', + required=False, help='The compression algorithm to use for MDS writing', ) parser.add_argument( '--concat_tokens', type=int, + required=True, help='Convert text to tokens and concatenate up to this many tokens', ) parser.add_argument( '--tokenizer', type=str, + required=True, help='The name of the tokenizer to use', ) parser.add_argument( @@ -77,6 +80,13 @@ def parse_args() -> Namespace: help= 'The text to append to each example to separate concatenated examples', ) + parser.add_argument( + '--use_tokenizer_eos', + required=False, + action='store_true', + default=False, + help='Use the EOS text from the tokenizer.', + ) parser.add_argument( '--no_wrap', default=False, @@ -103,11 +113,15 @@ def parse_args() -> Namespace: parsed = parser.parse_args() - # Make sure we have needed concat options - if (parsed.concat_tokens is not None and - isinstance(parsed.concat_tokens, int) and parsed.tokenizer is None): - parser.error( - 'When setting --concat_tokens, you must specify a --tokenizer') + # Set eos token. + if parsed.use_tokenizer_eos: + # Ensure that eos text is not specified twice. + if parsed.eos_text is not None: + parser.error( + 'Cannot set --eos_text with --use_tokenizer_eos. Please specify one.' + ) + tokenizer = AutoTokenizer.from_pretrained(parsed.tokenizer) + parsed.eos_text = tokenizer.eos_token # now that we have validated them, change BOS/EOS to strings if parsed.bos_text is None: From 4772ba290371a3f187433647e0c1504321c5bb41 Mon Sep 17 00:00:00 2001 From: Irene Dea Date: Mon, 8 Jan 2024 15:32:07 -0800 Subject: [PATCH 06/31] Disable omegaconf environment variables (#845) --- scripts/train/train.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/scripts/train/train.py b/scripts/train/train.py index ef7a3b91db..8c9fcc0291 100644 --- a/scripts/train/train.py +++ b/scripts/train/train.py @@ -635,6 +635,11 @@ def main(cfg: DictConfig) -> Trainer: if __name__ == '__main__': yaml_path, args_list = sys.argv[1], sys.argv[2:] + + # Disable resolving environment variables through omegaconf. + om.clear_resolver('oc.env') + + # Load yaml and cli arguments. with open(yaml_path) as f: yaml_cfg = om.load(f) cli_cfg = om.from_cli(args_list) From ddba5c8d7395223a80019e1a2db90af753317195 Mon Sep 17 00:00:00 2001 From: Brian <23239305+b-chu@users.noreply.github.com> Date: Tue, 9 Jan 2024 12:42:48 -0500 Subject: [PATCH 07/31] Bump pre-commit version (#847) This matches composer's pre-commit version https://github.com/mosaicml/composer/blob/0aa95e05f092e23fa677400a25dbef4cd2098343/setup.py#L113 --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 152c682a2d..3de80f2292 100644 --- a/setup.py +++ b/setup.py @@ -73,7 +73,7 @@ extra_deps = {} extra_deps['dev'] = [ - 'pre-commit>=2.18.1,<3', + 'pre-commit>=3.4.0,<4', 'pytest>=7.2.1,<8', 'pytest_codeblocks>=0.16.1,<0.17', 'pytest-cov>=4,<5', From d822630bad5dc87f2607e6f0f7944e6ad78f2d22 Mon Sep 17 00:00:00 2001 From: Irene Dea Date: Wed, 10 Jan 2024 11:44:23 -0800 Subject: [PATCH 08/31] Fix typo kwargs=>hf_kwargs (#853) * Fix kwargs typo, should be hf_kwargs * update test --- llmfoundry/data/finetuning/dataloader.py | 2 +- tests/data/test_dataloader.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/llmfoundry/data/finetuning/dataloader.py b/llmfoundry/data/finetuning/dataloader.py index 97725ce78c..4e1c3bbf9f 100644 --- a/llmfoundry/data/finetuning/dataloader.py +++ b/llmfoundry/data/finetuning/dataloader.py @@ -190,7 +190,7 @@ def build_finetuning_dataloader(cfg: DictConfig, max_seq_len=cfg.dataset.max_seq_len, preprocessing_fn=preprocessing_fn, tokenizer=tokenizer, - hf_kwargs=cfg.dataset.get('kwargs', {})) + hf_kwargs=cfg.dataset.get('hf_kwargs', {})) # Ensure dataset is large enough. if cfg.drop_last: diff --git a/tests/data/test_dataloader.py b/tests/data/test_dataloader.py index 7f99eeda25..44d0442a87 100644 --- a/tests/data/test_dataloader.py +++ b/tests/data/test_dataloader.py @@ -326,7 +326,7 @@ def test_finetuning_dataloader_safe_load(hf_name: str, 'decoder_only_format': True, 'shuffle': True, 'safe_load': True, - 'kwargs': { + 'hf_kwargs': { 'revision': hf_revision } }, From c694121ee44a0698d856c283cc3398df465915cd Mon Sep 17 00:00:00 2001 From: Anna Date: Wed, 10 Jan 2024 12:48:54 -0800 Subject: [PATCH 09/31] Remove foundry time wrangling (#855) * Remove foundry time wrangling * fix format * better types --- llmfoundry/callbacks/async_eval_callback.py | 26 +++++++-------------- llmfoundry/callbacks/hf_checkpointer.py | 10 +++----- llmfoundry/optim/scheduler.py | 21 +++++++---------- tests/optim/test_scheduler.py | 2 +- 4 files changed, 20 insertions(+), 39 deletions(-) diff --git a/llmfoundry/callbacks/async_eval_callback.py b/llmfoundry/callbacks/async_eval_callback.py index 8352a9e283..4227448d87 100644 --- a/llmfoundry/callbacks/async_eval_callback.py +++ b/llmfoundry/callbacks/async_eval_callback.py @@ -158,31 +158,21 @@ def get_eval_parameters( def validate_interval(interval: Union[str, int, Time], save_interval: Union[str, int, Time]) -> Time: - if isinstance(save_interval, str): - new_save_interval: Time = Time.from_timestring(save_interval) - elif isinstance(save_interval, int): - new_save_interval: Time = Time(save_interval, TimeUnit.EPOCH) - else: - new_save_interval: Time = save_interval - - if isinstance(interval, str): - result: Time = Time.from_timestring(interval) - elif isinstance(interval, int): - result: Time = Time(interval, TimeUnit.EPOCH) - else: - result: Time = interval - - if new_save_interval.unit != result.unit: + + new_save_interval = Time.from_input(save_interval, TimeUnit.EPOCH) + async_interval = Time.from_input(interval, TimeUnit.EPOCH) + + if new_save_interval.unit != async_interval.unit: raise ValueError( 'Save interval and async eval interval must be in the same unit') - if result < new_save_interval: + if async_interval < new_save_interval: raise ValueError( 'Async eval interval must be equal or greater (less frequent) than save interval' ) - if result.value % new_save_interval.value != 0: + if async_interval.value % new_save_interval.value != 0: raise ValueError( 'Async eval interval must be a multiple of save interval') - return result + return async_interval class AsyncEval(Callback): diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py index 491d510188..8b139c0c25 100644 --- a/llmfoundry/callbacks/hf_checkpointer.py +++ b/llmfoundry/callbacks/hf_checkpointer.py @@ -96,14 +96,10 @@ def __init__( self.huggingface_folder_name_fstr = os.path.join( 'huggingface', huggingface_folder_name) - if isinstance(save_interval, str): - save_interval = Time.from_timestring(save_interval) - if isinstance(save_interval, int): - save_interval = Time(save_interval, TimeUnit.EPOCH) - - self.save_interval: Time = save_interval + self.save_interval: Time = Time.from_input(save_interval, + TimeUnit.EPOCH) self.check_interval = create_interval_scheduler( - save_interval, include_end_of_training=True) + self.save_interval, include_end_of_training=True) self.remote_ud = maybe_create_remote_uploader_downloader_from_uri( save_folder, loggers=[]) if self.remote_ud is not None: diff --git a/llmfoundry/optim/scheduler.py b/llmfoundry/optim/scheduler.py index 4a6d21c873..5598db28a1 100644 --- a/llmfoundry/optim/scheduler.py +++ b/llmfoundry/optim/scheduler.py @@ -16,24 +16,19 @@ def _raise_if_units_dont_match(time: Union[str, Time], t_max: Union[str, Time], name: str) -> None: - if isinstance(time, str): - time = Time.from_timestring(time) - if isinstance(t_max, str): - t_max = Time.from_timestring(t_max) + new_time = Time.from_input(time) + new_t_max = Time.from_input(t_max) - assert not isinstance(time, str) and not isinstance(t_max, str) - - if time.unit != t_max.unit: - raise ValueError(f'{time.unit=} does not match {t_max.unit=}.') + if new_time.unit != new_t_max.unit: + raise ValueError( + f'{name} (unit {new_time.unit=}) must match max_duration unit ({new_t_max.unit=}).' + ) def _raise_if_units_dur(time: Union[str, Time], name: str) -> None: - if isinstance(time, str): - time = Time.from_timestring(time) - - assert not isinstance(time, str) + new_time = Time.from_input(time) - if time.unit == TimeUnit('dur'): + if new_time.unit == TimeUnit('dur'): raise ValueError(f'{name} cannot be in units of "dur".') diff --git a/tests/optim/test_scheduler.py b/tests/optim/test_scheduler.py index 5b9d45a141..811088bd62 100644 --- a/tests/optim/test_scheduler.py +++ b/tests/optim/test_scheduler.py @@ -88,7 +88,7 @@ def test_scheduler_units_match_error(state_unit: str, warmup_unit: str, t_warmup=f'10{warmup_unit}', t_scale=f'10{scale_unit}', t_cooldown=f'10{cooldown_unit}') - with pytest.raises(ValueError, match='does not match'): + with pytest.raises(ValueError, match='must match'): _ = scheduler(state, 1.0) From a7c36bccf474441d5ab845fe0d91eca336e65d7f Mon Sep 17 00:00:00 2001 From: Mihir Patel Date: Wed, 10 Jan 2024 23:47:00 -0500 Subject: [PATCH 10/31] Minor cleanups (#858) * nits * logger * add log * lint --- llmfoundry/models/mpt/modeling_mpt.py | 4 ++-- llmfoundry/utils/config_utils.py | 14 +++----------- scripts/train/train.py | 6 +++++- 3 files changed, 10 insertions(+), 14 deletions(-) diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index e2274ffd6c..8b14c72f62 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -330,12 +330,12 @@ def __init__(self, config: MPTConfig): for module in self.modules(): if hasattr(module, 'bias') and isinstance( module.bias, nn.Parameter): - log.info(f'Removing bias ({module.bias}) from {module}.') + log.info(f'Removing bias from {module=}.') module.register_parameter('bias', None) # For transformer engine if hasattr(module, 'use_bias'): - log.info(f'Setting use_bias=False for {module}.') + log.info(f'Setting use_bias=False for {module=}.') module.use_bias = False log.debug(self) diff --git a/llmfoundry/utils/config_utils.py b/llmfoundry/utils/config_utils.py index 55576eaba0..29d78a0770 100644 --- a/llmfoundry/utils/config_utils.py +++ b/llmfoundry/utils/config_utils.py @@ -120,18 +120,10 @@ def process_init_device(model_cfg: DictConfig, fsdp_config: Optional[Dict]): # Set defaults for mixed initialization fsdp_config.setdefault('use_orig_params', False) fsdp_config.setdefault('load_monolith_rank0_only', True) - # Always set `sync_module_states` to True when using hybrid sharding - if fsdp_config is not None and \ - fsdp_config.get('sharding_strategy', 'FULL_SHARD') in ['HYBRID_SHARD', '_HYBRID_SHARD_ZERO2'] \ - and not fsdp_config.get('sync_module_states', False): - warnings.warn( - ('Setting `sync_module_states = True` for FSDP. This is required ' - 'when using hybrid sharding.')) - fsdp_config['sync_module_states'] = True - - # no mixed precision needed for weights when they're already 16 bits + + # No mixed precision needed for weights when they're already 16 bits master_dtype = model_cfg.get('master_weights_dtype') - small_dtypes = ('bf16', 'f16', 'float16', 'bfloat16', 'amp_fp16', + small_dtypes = ('bf16', 'fp16', 'float16', 'bfloat16', 'amp_fp16', 'amp_bf16') if fsdp_config and master_dtype in small_dtypes: reduce_dtype = None diff --git a/scripts/train/train.py b/scripts/train/train.py index 8c9fcc0291..c3da1f1d3c 100644 --- a/scripts/train/train.py +++ b/scripts/train/train.py @@ -438,13 +438,17 @@ def main(cfg: DictConfig) -> Trainer: format= f'%(asctime)s: rank{dist.get_global_rank()}[%(process)d][%(threadName)s]: %(levelname)s: %(name)s: %(message)s' ) - logging.getLogger('llmfoundry').setLevel(python_log_level.upper()) + logging.getLogger('llmfoundry').setLevel( + python_log_level.upper()) # Foundry module + logging.getLogger(__name__).setLevel( + python_log_level.upper()) # Train script # Initialize context init_context = process_init_device(model_config, fsdp_config) logged_cfg.update({'fsdp_config': fsdp_config}, merge=True) # Build tokenizer + log.info('Building tokenizer...') tokenizer_name = tokenizer_config['name'] tokenizer_kwargs = tokenizer_config.get('kwargs', {}) tokenizer = build_tokenizer(tokenizer_name, tokenizer_kwargs) From 6de8c37f5476df726f7bf96b7a1e2c74e7612164 Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Thu, 11 Jan 2024 06:29:31 -0800 Subject: [PATCH 11/31] Read UC delta table (#773) * initial commit * use databricks-sql to read delta table and convert to json * update * update * update * add mocked unittest * Fix lints * update * update * restructure code * Add timer for optimizing * Add db-connect * add wrapper * update * add install dbconnect * update * update * patch dbconnect to allow multiple return formats * update * add arrow * use compression * clean up * Add cluster rt check * Fix lints * remove patch.py for CI * update * update * updat * update * fix tests * fix lint * update * update * Add more tests * update * update * update * change to download_json * update * fix lints * Add decompressed option for arrow * format json to jsonl * Add comments * Make cf_collect_type global option * fix comments * fix lints * fix comments * Fix lints * change to use workspaceclient * Add CPT support * Rewire method assignment logic * Fix bug in stripping https * Add tests for rewired method assignment logic * Fix lints * Fix lints * Removed logger set_level * Remove pyspark. It conflicts with databricks-connect * Update the comment * skip cluster version check when cluster_id is serverless * Add use_serverless flag * update tests with use_serverless flag * Fix lints --------- Co-authored-by: Xiaohan Zhang --- scripts/data_prep/convert_delta_to_json.py | 517 ++++++++++++++++++ setup.py | 5 +- .../data_prep/test_convert_delta_to_json.py | 304 ++++++++++ 3 files changed, 825 insertions(+), 1 deletion(-) create mode 100644 scripts/data_prep/convert_delta_to_json.py create mode 100644 tests/a_scripts/data_prep/test_convert_delta_to_json.py diff --git a/scripts/data_prep/convert_delta_to_json.py b/scripts/data_prep/convert_delta_to_json.py new file mode 100644 index 0000000000..8986849a42 --- /dev/null +++ b/scripts/data_prep/convert_delta_to_json.py @@ -0,0 +1,517 @@ +# Copyright 2022 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +import logging +import os +import re +import time +import urllib.parse +from argparse import ArgumentParser, Namespace +from collections import namedtuple +from concurrent.futures import ProcessPoolExecutor +from typing import Iterable, List, Optional, Tuple, Union +from uuid import uuid4 + +import google.protobuf.any_pb2 as any_pb2 +import lz4.frame +import pandas as pd +import pyarrow as pa +import pyspark.sql.connect.proto as pb2 +import pyspark.sql.connect.proto.cloud_pb2 as cloud_pb2 +import requests +from databricks import sql +from databricks.connect import DatabricksSession +from databricks.sdk import WorkspaceClient +from databricks.sql.client import Connection as Connection +from databricks.sql.client import Cursor as Cursor +from packaging import version +from pyspark.sql import SparkSession +from pyspark.sql.connect.client.core import SparkConnectClient +from pyspark.sql.connect.client.reattach import \ + ExecutePlanResponseReattachableIterator +from pyspark.sql.connect.dataframe import DataFrame +from pyspark.sql.dataframe import DataFrame as SparkDataFrame +from pyspark.sql.types import Row + +MINIMUM_DB_CONNECT_DBR_VERSION = '14.1.0' +MINIMUM_SQ_CONNECT_DBR_VERSION = '12.2.0' + +log = logging.getLogger(__name__) + +Result = namedtuple( + 'Result', ['url', 'row_count', 'compressed_size', 'uncompressed_size' + ]) # pyright: ignore + +# ``collect_as_cf`` is an addon new feature monkey patch on top of the DB Connect package. +# It allows the client to fetch the results in different formats from the server. +# To be able to use the code make sure this module is not overriden by DB Connect classes. + + +def to_cf(self: SparkConnectClient, + plan: pb2.Plan, + type: str = 'json') -> Tuple[List[Result], int, bool]: + """Executes the query plans and return as presigned URLS for cloud fetch. + + It can handle the current output formats that are supported by the server. + In contrast to the regular API methods of the client, this method does not + return the schema and drops all other responses. + + Args: + plan (pb2.Plan): The plan object to be executed by spark. + type (str): The output format of the result, supported formats are 'json', 'csv', and 'arrow'. + + Returns: + Tuple[List[Result], int, bool]: A tuple containing: + - A list of Result namedtuples, each containing a URL, row count, compressed size, + and uncompressed size of the part of the result. + - Total row count of all parts of the result. + - A boolean indicating whether the result has been truncated. + """ + req = self._execute_plan_request_with_metadata() + req.plan.CopyFrom(plan) + + # Add the request options + if type == 'json': + format = cloud_pb2.ResultOptions.CloudOptions.FORMAT_JSON + elif type == 'csv': + format = cloud_pb2.ResultOptions.CloudOptions.FORMAT_CSV + elif type == 'arrow': + format = cloud_pb2.ResultOptions.CloudOptions.FORMAT_ARROW + else: + raise ValueError( + f'Only formats json, csv, and arrow are supported. Got invalid type {type}' + ) + + ro = cloud_pb2.ResultOptions( + type=cloud_pb2.ResultOptions.TYPE_CLOUD, + cloudOptions=cloud_pb2.ResultOptions.CloudOptions( + format=format, + useCompression=False, + )) + cloud_option = any_pb2.Any() + cloud_option.Pack(ro) + req.request_options.append( + pb2.ExecutePlanRequest.RequestOption(extension=cloud_option)) + + # Create the iterator + iterator = ExecutePlanResponseReattachableIterator(req, self._stub, + self._retry_policy, + self._builder.metadata()) + # Iterate over the response + result = [] + row_count = 0 + is_overflow = False + + for response in iterator: + if response.HasField('extension') and response.extension.Is( + cloud_pb2.CloudResultBatch.DESCRIPTOR): + batch = cloud_pb2.CloudResultBatch() + if not response.extension.Is(cloud_pb2.CloudResultBatch.DESCRIPTOR): + raise ValueError( + 'Response extension is not of type CloudResultBatch.') + response.extension.Unpack(batch) + result += [ + Result(b.url, b.row_count, b.compressed_size, + b.uncompressed_size) for b in batch.results + ] + row_count += sum(result.row_count for result in batch.results) + is_overflow |= batch.truncated + return result, row_count, is_overflow + + +SparkConnectClient.to_cf = to_cf # pyright: ignore + + +def collect_as_cf(self: DataFrame, + type: str = 'json') -> Tuple[List[Result], int, bool]: + """Collects DataFrame execution plan as presigned URLs. + + This method is a wrapper around the `to_cf` method of SparkConnectClient. It takes the + execution plan of the current DataFrame, converts it to a protocol buffer format, and then + uses the `to_cf` method to execute the plan and fetch results as presigned URLs. + + Args: + type (str): The output format of the result, supported formats are 'json', 'csv', and 'arrow'. + + Returns: + Tuple[List[Result], int, bool]: A tuple containing: + - A list of Result namedtuples, each containing a URL, row count, compressed size, + and uncompressed size of the part of the result. + - Total row count of all parts of the result. + - A boolean indicating whether the result is truncated or overflowed. + """ + query = self._plan.to_proto(self._session.client) # pyright: ignore + return self._session.client.to_cf(query, type) # pyright: ignore + + +DataFrame.collect_cf = collect_as_cf # pyright: ignore + + +def iterative_combine_jsons(json_directory: str, output_file: str) -> None: + """Combine jsonl files in json_directory into one big jsonl file. + + This function does not work for nested subdirectories. + + Args: + json_directory(str): directory containing the JSONL files + output_file(str): path to the output combined JSONL file + """ + json_files = [f for f in os.listdir(json_directory) if f.endswith('.jsonl')] + with open(output_file, 'w') as outfile: + for file_name in json_files: + with open(os.path.join(json_directory, file_name), 'r') as infile: + for line in infile: + outfile.write(line) + log.info('JSON files have been combined into a JSONL file.') + + +def run_query( + query: str, + method: str, + cursor: Optional[Cursor] = None, + spark: Optional[SparkSession] = None, + collect: bool = True +) -> Optional[Union[List[Row], DataFrame, SparkDataFrame]]: + """Run SQL query via databricks-connect or databricks-sql. + + Args: + query (str): sql query + method (str): select from dbsql and dbconnect + cursor (Optional[Cursor]): connection.cursor + spark (Optional[SparkSession]): spark session + collect (bool): whether to get the underlying data from spark dataframe + """ + if method == 'dbsql': + if cursor is None: + raise ValueError(f'cursor cannot be None if using method dbsql') + cursor.execute(query) + if collect: + return cursor.fetchall() + elif method == 'dbconnect': + if spark == None: + raise ValueError(f'sparkSession is required for dbconnect') + df = spark.sql(query) + if collect: + return df.collect() + return df + else: + raise ValueError(f'Unrecognized method: {method}') + + +def get_args(signed: List, json_output_path: str, columns: List) -> Iterable: + for i, r in enumerate(signed): + yield (i, r.url, json_output_path, columns) + + +def download(ipart: int, + url: str, + json_output_path: str, + columns: Optional[List] = None, + resp_format: str = 'arrow', + compressed: bool = False) -> None: + """Thread download presigned url and save to jsonl locally. + + Args: + ipart (int): presigned url id + url (str): presigned url + json_output_path (str): directory to save the ipart_th segment of dataframe + columns (list): schema to save to json + resp_format (str): whether to use arrow or json when collect + compressed (bool): if data is compressed before downloading. Need decompress if compressed=True. + """ + resp = requests.get(url) + if resp.status_code == 200: + if resp_format == 'json': + data = resp.json() + pd.DataFrame(data, columns=columns).to_json(os.path.join( + json_output_path, 'part_' + str(ipart) + '.jsonl'), + orient='records', + lines=True) + return + + # When resp_format is arrow: + if compressed: + # The data is lz4 compressed arrow format. + # Decompress the data + decompressed_data = lz4.frame.decompress(resp.content) + # Convert the decompressed data into a PyArrow table + reader = pa.ipc.open_stream(decompressed_data) + else: + reader = pa.ipc.open_stream(resp.content) + table = reader.read_all() + + # Convert the PyArrow table into a pandas DataFrame + df = table.to_pandas() + df.to_json(os.path.join(json_output_path, + 'part_' + str(ipart) + '.jsonl'), + orient='records', + lines=True, + force_ascii=False) + + +def download_starargs(args: Tuple) -> None: + return download(*args) + + +def fetch_data(method: str, cursor: Optional[Cursor], + sparkSession: Optional[SparkSession], start: int, end: int, + order_by: str, tablename: str, columns_str: str, + json_output_path: str) -> None: + """Fetches a specified range of rows from a given table to a json file. + + This function executes a SQL query to retrieve a range of rows, determined by 'start' and 'end' indexes, + from a specified table and column set. The fetched data is then exported as a JSON file. + + Args: + method (str): The method to use for fetching data, either 'dbconnect' or 'dbsql'. + cursor (Optional[Cursor]): The cursor object for executing queries in 'dbsql' method. + sparkSession (Optional[SparkSession]): The Spark session object for executing queries in 'dbconnect' method. + start (int): The starting index for row fetching. + end (int): The ending index for row fetching. + order_by (str): The column name to use for ordering the rows. + tablename (str): The name of the table from which to fetch the data. + columns_str (str): The string representation of the columns to select from the table. + json_output_path (str): The file path where the resulting JSON file will be saved. + + Returns: + None: The function doesn't return any value, but writes the result to a JSONL file. + """ + query = f""" + WITH NumberedRows AS ( + SELECT + *, + ROW_NUMBER() OVER (ORDER BY {order_by}) AS rn + FROM + {tablename} + ) + SELECT {columns_str} + FROM NumberedRows + WHERE rn BETWEEN {start+1} AND {end}""" + + if method == 'dbconnect': + spark_df = run_query(query, method, cursor, sparkSession, collect=False) + if spark_df is None: + raise RuntimeError( + f'Expect spark dataframe with {query} but got None') + pdf = spark_df.toPandas() # pyright: ignore + else: # method == 'dbsql': + ans = run_query(query, method, cursor, sparkSession, collect=True) + if ans is None: + raise RuntimeError(f'Got empty results with {query}') + records = [r.asDict() for r in ans] # pyright: ignore + pdf = pd.DataFrame.from_dict(records) + + pdf.to_json(os.path.join(json_output_path, f'part_{start+1}_{end}.jsonl'), + orient='records', + lines=True) + + +def fetch( + method: str, + tablename: str, + json_output_path: str, + batch_size: int = 1 << 30, + processes: int = 1, + sparkSession: Optional[SparkSession] = None, + dbsql: Optional[Connection] = None, +) -> None: + """Fetch UC delta table with databricks-connnect as JSONL. + + Args: + method (str): dbconnect or dbsql + tablename (str): catalog.scheme.tablename on UC + json_output_path (str): path to write the result json file to + batch_size (int): number of rows that dbsql fetches each time to avoid OOM + processes (int): max number of processes to use to parallelize the fetch + sparkSession (pyspark.sql.sparksession): spark session + dbsql (databricks.sql.connect): dbsql session + """ + cursor = dbsql.cursor() if dbsql is not None else None + + try: + ans = run_query(f'SELECT COUNT(*) FROM {tablename}', method, cursor, + sparkSession) + nrows = [row.asDict() for row in ans][0].popitem()[1] # pyright: ignore + log.info(f'total_rows = {nrows}') + except Exception as e: + raise RuntimeError( + f'Error in get total rows from {tablename}. Restart sparkSession and try again' + ) from e + + try: + ans = run_query(f'SHOW COLUMNS IN {tablename}', method, cursor, + sparkSession) + columns = [row.asDict().popitem()[1] for row in ans] # pyright: ignore + order_by = columns[0] + columns_str = ','.join(columns) + log.info(f'order by column {order_by}') + except Exception as e: + raise RuntimeError( + f'Error in get columns from {tablename}. Restart sparkSession and try again' + ) from e + + if method == 'dbconnect' and sparkSession is not None: + log.info('processes = ', processes) + df = sparkSession.table(tablename) + + # Running the query and collecting the data as arrow or json. + signed, _, _ = df.collect_cf('arrow') # pyright: ignore + log.info(f'len(signed) = {len(signed)}') + + args = get_args(signed, json_output_path, columns) + + # Stopping the SparkSession to avoid spilling connection state into the subprocesses. + sparkSession.stop() + + with ProcessPoolExecutor(max_workers=processes) as executor: + list(executor.map(download_starargs, args)) + + elif method == 'dbsql' and cursor is not None: + for start in range(0, nrows, batch_size): + log.warning(f'batch {start}') + end = min(start + batch_size, nrows) + fetch_data(method, cursor, sparkSession, start, end, order_by, + tablename, columns_str, json_output_path) + + if cursor is not None: + cursor.close() + + +def fetch_DT(args: Namespace) -> None: + """Fetch UC Delta Table to local as jsonl.""" + log.info(f'Start .... Convert delta to json') + + obj = urllib.parse.urlparse(args.json_output_path) + if obj.scheme != '': + raise ValueError( + f'Check the json_output_path and verify it is a local path!') + + if os.path.exists(args.json_output_path): + if not os.path.isdir(args.json_output_path) or os.listdir( + args.json_output_path): + raise RuntimeError( + f'A file or a folder {args.json_output_path} already exists and is not empty. Remove it and retry!' + ) + + os.makedirs(args.json_output_path, exist_ok=True) + + log.info(f'Directory {args.json_output_path} created.') + + method = 'dbsql' + dbsql = None + sparkSession = None + + if args.use_serverless: + method = 'dbconnect' + else: + w = WorkspaceClient() + res = w.clusters.get(cluster_id=args.cluster_id) + runtime_version = res.spark_version.split('-scala')[0].replace( + 'x-snapshot', '0').replace('x', '0') + if version.parse(runtime_version) < version.parse( + MINIMUM_SQ_CONNECT_DBR_VERSION): + raise ValueError( + f'The minium DBR version required is {MINIMUM_SQ_CONNECT_DBR_VERSION} but got {version.parse(runtime_version)}' + ) + + if args.http_path is None and version.parse( + runtime_version) >= version.parse( + MINIMUM_DB_CONNECT_DBR_VERSION): + method = 'dbconnect' + + if method == 'dbconnect': + try: + if args.use_serverless: + session_id = str(uuid4()) + sparkSession = DatabricksSession.builder.host( + args.DATABRICKS_HOST).token(args.DATABRICKS_TOKEN).header( + 'x-databricks-session-id', session_id).getOrCreate() + + else: + sparkSession = DatabricksSession.builder.remote( + host=args.DATABRICKS_HOST, + token=args.DATABRICKS_TOKEN, + cluster_id=args.cluster_id).getOrCreate() + + except Exception as e: + raise RuntimeError( + 'Failed to create databricks connection. Check hostname and access token!' + ) from e + else: + try: + dbsql = sql.connect( + server_hostname=re.compile(r'^https?://').sub( + '', args.DATABRICKS_HOST).strip( + ), # sqlconnect hangs if hostname starts with https + http_path=args.http_path, + access_token=args.DATABRICKS_TOKEN, + ) + except Exception as e: + raise RuntimeError( + 'Failed to create sql connection to db workspace. To use sql connect, you need to provide http_path and cluster_id!' + ) from e + + fetch(method, args.delta_table_name, args.json_output_path, args.batch_size, + args.processes, sparkSession, dbsql) + + if dbsql is not None: + dbsql.close() + + # combine downloaded jsonl into one big jsonl for IFT + iterative_combine_jsons( + args.json_output_path, + os.path.join(args.json_output_path, 'combined.jsonl')) + + +if __name__ == '__main__': + parser = ArgumentParser( + description= + 'Download delta table from UC and convert to json to save local') + parser.add_argument('--delta_table_name', + required=True, + type=str, + help='UC table ..') + parser.add_argument('--json_output_path', + required=True, + type=str, + help='Local path to save the converted json') + parser.add_argument('--http_path', + required=False, + type=str, + help='http_path is set then dbsql method is used') + parser.add_argument('--batch_size', + required=False, + type=int, + default=1 << 30, + help='row chunks to transmit a time to avoid OOM') + parser.add_argument('--processes', + required=False, + type=int, + default=os.cpu_count(), + help='number of processes allowed to use') + parser.add_argument( + '--cluster_id', + required=True, + type=str, + default=None, + help= + 'cluster id has runtime newer than 14.1.0 and access mode of either assigned or shared can use databricks-connect.' + ) + parser.add_argument( + '--use_serverless', + required=False, + type=bool, + default=False, + help= + 'Use serverless or not. Make sure the workspace is entitled with serverless' + ) + args = parser.parse_args() + + from databricks.sdk import WorkspaceClient + w = WorkspaceClient() + args.DATABRICKS_HOST = w.config.host + args.DATABRICKS_TOKEN = w.config.token + + tik = time.time() + fetch_DT(args) + log.info('Elapsed time', time.time() - tik) diff --git a/setup.py b/setup.py index 3de80f2292..5444352cf7 100644 --- a/setup.py +++ b/setup.py @@ -84,7 +84,10 @@ ] extra_deps['databricks'] = [ - 'mosaicml[databricks]>=0.17.2,<0.18', + 'mosaicml[databricks]>=0.17.1,<0.18', + 'databricks-sql-connector>=3,<4', + 'databricks-connect==14.1.0', + 'lz4>=4,<5', ] extra_deps['tensorboard'] = [ diff --git a/tests/a_scripts/data_prep/test_convert_delta_to_json.py b/tests/a_scripts/data_prep/test_convert_delta_to_json.py new file mode 100644 index 0000000000..39bc5d8099 --- /dev/null +++ b/tests/a_scripts/data_prep/test_convert_delta_to_json.py @@ -0,0 +1,304 @@ +# Copyright 2022 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +# copyright 2022 mosaicml llm foundry authors +# spdx-license-identifier: apache-2.0 + +import unittest +from argparse import Namespace +from typing import Any +from unittest.mock import MagicMock, mock_open, patch + +from scripts.data_prep.convert_delta_to_json import (download, fetch_DT, + iterative_combine_jsons, + run_query) + + +class TestConverDeltaToJsonl(unittest.TestCase): + + @patch('scripts.data_prep.convert_delta_to_json.sql.connect') + @patch('scripts.data_prep.convert_delta_to_json.os.makedirs') + @patch('scripts.data_prep.convert_delta_to_json.iterative_combine_jsons') + @patch('scripts.data_prep.convert_delta_to_json.fetch') + @patch('scripts.data_prep.convert_delta_to_json.WorkspaceClient') + def test_stream_delta_to_json(self, mock_workspace_client: Any, + mock_fetch: Any, mock_combine_jsons: Any, + mock_makedirs: Any, mock_sql_connect: Any): + + args = MagicMock() + args.delta_table_name = 'test_table' + args.json_output_path = '/path/to/jsonl' + args.DATABRICKS_HOST = 'test_host' + args.DATABRICKS_TOKEN = 'test_token' + args.http_path = 'test_path' + args.batch_size = 1000 + args.partitions = 1 + args.cluster_id = '1234' + args.debug = False + args.use_serverless = False + + mock_cluster_get = MagicMock() + mock_cluster_get.return_value = MagicMock( + spark_version='14.1.0-scala2.12') + mock_workspace_client.return_value.clusters.get = mock_cluster_get + + fetch_DT(args) + mock_sql_connect.assert_called_once_with(server_hostname='test_host', + http_path='test_path', + access_token='test_token') + mock_makedirs.assert_called_once_with('/path/to/jsonl', exist_ok=True) + mock_fetch.assert_called_once() + mock_combine_jsons.assert_called_once_with( + '/path/to/jsonl', '/path/to/jsonl/combined.jsonl') + + @patch('scripts.data_prep.convert_delta_to_json.os.listdir') + @patch('builtins.open', + new_callable=mock_open, + read_data='{"key": "value"}') + def test_iterative_combine_jsons(self, mock_file: Any, mock_listdir: Any): + mock_listdir.return_value = ['file1.jsonl', 'file2.jsonl'] + json_directory = '/fake/dir' + output_file = '/fake/output.jsonl' + + iterative_combine_jsons(json_directory, output_file) + + mock_listdir.assert_called_once_with(json_directory) + mock_file.assert_called() + """ + Diagnostic print + for call_args in mock_file().write.call_args_list: + print(call_args) + -------------------- + call('{') + call('"key"') + call(': ') + call('"value"') + call('}') + call('\n') + call('{') + call('"key"') + call(': ') + call('"value"') + call('}') + call('\n') + -------------------- + """ + self.assertEqual(mock_file().write.call_count, 2) + + @patch('scripts.data_prep.convert_delta_to_json.SparkSession') + def test_run_query_dbconnect(self, mock_spark: Any): + method = 'dbconnect' + mock_cursor = None + mock_spark.sql.return_value.collect.return_value = 'result' + + result = run_query('SELECT * FROM table', + method, + cursor=mock_cursor, + spark=mock_spark) + + mock_spark.sql.assert_called_once_with('SELECT * FROM table') + self.assertEqual(result, 'result') + + @patch('scripts.data_prep.convert_delta_to_json.Cursor') + def test_run_query_dbsql(self, mock_cursor: Any): + method = 'dbsql' + mock_cursor.fetchall.return_value = 'result' + mock_spark = None + + result = run_query('SELECT * FROM table', + method, + cursor=mock_cursor, + spark=mock_spark) + + mock_cursor.execute.assert_called_once_with('SELECT * FROM table') + self.assertEqual(result, 'result') + + @patch('scripts.data_prep.convert_delta_to_json.requests.get') + @patch('scripts.data_prep.convert_delta_to_json.pd.DataFrame.to_json') + @patch('scripts.data_prep.convert_delta_to_json.os.path.join', + return_value='/fake/path/part_1.jsonl') + @patch('scripts.data_prep.convert_delta_to_json.time.sleep' + ) # Mock sleep to speed up the test + def test_download_success(self, mock_sleep: Any, mock_join: Any, + mock_to_json: Any, mock_get: Any): + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = [['val1.1', 'val1.2'], + ['val2.1', 'val2.2']] + mock_get.return_value = mock_response + + download(1, + 'http://fakeurl.com/data', + '/fake/path', ['A', 'B'], + resp_format='json') + + mock_get.assert_called_with('http://fakeurl.com/data') + mock_join.assert_called_with('/fake/path', 'part_1.jsonl') + mock_to_json.assert_called_with('/fake/path/part_1.jsonl', + orient='records', + lines=True) + + mock_get.assert_called_once_with('http://fakeurl.com/data') + + @patch('scripts.data_prep.convert_delta_to_json.sql.connect') + @patch('scripts.data_prep.convert_delta_to_json.DatabricksSession') + @patch('scripts.data_prep.convert_delta_to_json.WorkspaceClient') + @patch('scripts.data_prep.convert_delta_to_json.os.makedirs') + @patch('scripts.data_prep.convert_delta_to_json.iterative_combine_jsons') + @patch('scripts.data_prep.convert_delta_to_json.fetch') + def test_dbconnect_called(self, mock_fetch: Any, mock_combine_jsons: Any, + mock_makedirs: Any, mock_workspace_client: Any, + mock_databricks_session: Any, + mock_sql_connect: Any): + + args = MagicMock() + + args.delta_table_name = 'test_table' + args.json_output_path = '/path/to/jsonl' + # Execute function with http_path=None (should use dbconnect) + args.http_path = None + args.cluster_id = '1234' + args.DATABRICKS_HOST = 'host' + args.DATABRICKS_TOKEN = 'token' + args.use_serverless = False + + mock_cluster_response = Namespace(spark_version='14.1.0-scala2.12') + mock_workspace_client.return_value.clusters.get.return_value = mock_cluster_response + + mock_remote = MagicMock() + mock_remote.getOrCreate.return_value = MagicMock( + ) # Mock return value for getOrCreate + mock_databricks_session.builder.remote.return_value = mock_remote + + fetch_DT(args) + mock_databricks_session.builder.remote.assert_called_once_with( + host=args.DATABRICKS_HOST, + token=args.DATABRICKS_TOKEN, + cluster_id=args.cluster_id) + + @patch('scripts.data_prep.convert_delta_to_json.sql.connect') + @patch('scripts.data_prep.convert_delta_to_json.DatabricksSession') + @patch('scripts.data_prep.convert_delta_to_json.WorkspaceClient') + @patch('scripts.data_prep.convert_delta_to_json.os.makedirs') + @patch('scripts.data_prep.convert_delta_to_json.iterative_combine_jsons') + @patch('scripts.data_prep.convert_delta_to_json.fetch') + def test_sqlconnect_called_dbr13(self, mock_fetch: Any, + mock_combine_jsons: Any, + mock_makedirs: Any, + mock_workspace_client: Any, + mock_databricks_session: Any, + mock_sql_connect: Any): + + args = MagicMock() + + args.delta_table_name = 'test_table' + args.json_output_path = '/path/to/jsonl' + # Execute function with http_path=None (should use dbconnect) + args.http_path = 'test_path' + args.cluster_id = '1234' + args.DATABRICKS_HOST = 'host' + args.DATABRICKS_TOKEN = 'token' + args.use_serverless = False + + mock_cluster_response = Namespace(spark_version='13.0.0-scala2.12') + mock_workspace_client.return_value.clusters.get.return_value = mock_cluster_response + + fetch_DT(args) + mock_sql_connect.assert_called_once_with( + server_hostname=args.DATABRICKS_HOST, + http_path=args.http_path, + access_token=args.DATABRICKS_TOKEN) + + @patch('scripts.data_prep.convert_delta_to_json.sql.connect') + @patch('scripts.data_prep.convert_delta_to_json.DatabricksSession') + @patch('scripts.data_prep.convert_delta_to_json.WorkspaceClient') + @patch('scripts.data_prep.convert_delta_to_json.os.makedirs') + @patch('scripts.data_prep.convert_delta_to_json.iterative_combine_jsons') + @patch('scripts.data_prep.convert_delta_to_json.fetch') + def test_sqlconnect_called_dbr14(self, mock_fetch: Any, + mock_combine_jsons: Any, + mock_makedirs: Any, + mock_workspace_client: Any, + mock_databricks_session: Any, + mock_sql_connect: Any): + + args = MagicMock() + + args.delta_table_name = 'test_table' + args.json_output_path = '/path/to/jsonl' + # Execute function with http_path=None (should use dbconnect) + args.http_path = 'test_path' + args.cluster_id = '1234' + args.DATABRICKS_HOST = 'host' + args.DATABRICKS_TOKEN = 'token' + args.use_serverless = False + + mock_cluster_response = Namespace(spark_version='14.2.0-scala2.12') + mock_workspace_client.return_value.clusters.get.return_value = mock_cluster_response + + fetch_DT(args) + mock_sql_connect.assert_called_once_with( + server_hostname=args.DATABRICKS_HOST, + http_path=args.http_path, + access_token=args.DATABRICKS_TOKEN) + + @patch('scripts.data_prep.convert_delta_to_json.sql.connect') + @patch('scripts.data_prep.convert_delta_to_json.DatabricksSession') + @patch('scripts.data_prep.convert_delta_to_json.WorkspaceClient') + @patch('scripts.data_prep.convert_delta_to_json.os.makedirs') + @patch('scripts.data_prep.convert_delta_to_json.iterative_combine_jsons') + @patch('scripts.data_prep.convert_delta_to_json.fetch') + def test_sqlconnect_called_https(self, mock_fetch: Any, + mock_combine_jsons: Any, + mock_makedirs: Any, + mock_workspace_client: Any, + mock_databricks_session: Any, + mock_sql_connect: Any): + + args = MagicMock() + + args.delta_table_name = 'test_table' + args.json_output_path = '/path/to/jsonl' + # Execute function with http_path=None (should use dbconnect) + args.http_path = 'test_path' + args.cluster_id = '1234' + args.DATABRICKS_HOST = 'https://test-host' + args.DATABRICKS_TOKEN = 'token' + args.use_serverless = False + + mock_cluster_response = Namespace(spark_version='14.2.0-scala2.12') + mock_workspace_client.return_value.clusters.get.return_value = mock_cluster_response + + fetch_DT(args) + mock_sql_connect.assert_called_once_with( + server_hostname='test-host', + http_path=args.http_path, + access_token=args.DATABRICKS_TOKEN) + + @patch('scripts.data_prep.convert_delta_to_json.sql.connect') + @patch('scripts.data_prep.convert_delta_to_json.DatabricksSession') + @patch('scripts.data_prep.convert_delta_to_json.WorkspaceClient') + @patch('scripts.data_prep.convert_delta_to_json.os.makedirs') + @patch('scripts.data_prep.convert_delta_to_json.iterative_combine_jsons') + @patch('scripts.data_prep.convert_delta_to_json.fetch') + def test_serverless(self, mock_fetch: Any, mock_combine_jsons: Any, + mock_makedirs: Any, mock_workspace_client: Any, + mock_databricks_session: Any, mock_sql_connect: Any): + + args = MagicMock() + + args.delta_table_name = 'test_table' + args.json_output_path = '/path/to/jsonl' + # Execute function with http_path=None (should use dbconnect) + args.http_path = 'test_path' + args.cluster_id = '1234' + args.DATABRICKS_HOST = 'https://test-host' + args.DATABRICKS_TOKEN = 'token' + args.use_serverless = True + + mock_cluster_response = Namespace(spark_version='14.2.0-scala2.12') + mock_workspace_client.return_value.clusters.get.return_value = mock_cluster_response + + fetch_DT(args) + assert not mock_sql_connect.called + assert not mock_databricks_session.builder.remote.called From fa8f3d96e7bf53f8e21ed0fefc8dfba1bb269c18 Mon Sep 17 00:00:00 2001 From: Mihir Patel Date: Thu, 11 Jan 2024 14:43:58 -0500 Subject: [PATCH 12/31] remove fused layernorm (#859) --- llmfoundry/utils/builders.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/llmfoundry/utils/builders.py b/llmfoundry/utils/builders.py index 404ad604ab..75438b895e 100644 --- a/llmfoundry/utils/builders.py +++ b/llmfoundry/utils/builders.py @@ -243,8 +243,6 @@ def build_algorithm(name: str, kwargs: Dict[str, Any]) -> Algorithm: return algorithms.GradientClipping(**kwargs) elif name == 'alibi': return algorithms.Alibi(**kwargs) - elif name == 'fused_layernorm': - return algorithms.FusedLayerNorm(**kwargs) elif name == 'gated_linear_units': return algorithms.GatedLinearUnits(**kwargs) elif name == 'low_precision_layernorm': From da3bea1487c589631ac95b3be346891f09d993cf Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Thu, 11 Jan 2024 18:41:23 -0800 Subject: [PATCH 13/31] Remove hardcoded combined.jsonl with a flag (#861) * Remove hardcoded combined.jsonl with a flag * update * change output_json_path output_json_folder --------- Co-authored-by: Xiaohan Zhang --- scripts/data_prep/convert_delta_to_json.py | 61 +++++++++++-------- .../data_prep/test_convert_delta_to_json.py | 13 ++-- 2 files changed, 42 insertions(+), 32 deletions(-) diff --git a/scripts/data_prep/convert_delta_to_json.py b/scripts/data_prep/convert_delta_to_json.py index 8986849a42..029ce7f5c3 100644 --- a/scripts/data_prep/convert_delta_to_json.py +++ b/scripts/data_prep/convert_delta_to_json.py @@ -198,14 +198,14 @@ def run_query( raise ValueError(f'Unrecognized method: {method}') -def get_args(signed: List, json_output_path: str, columns: List) -> Iterable: +def get_args(signed: List, json_output_folder: str, columns: List) -> Iterable: for i, r in enumerate(signed): - yield (i, r.url, json_output_path, columns) + yield (i, r.url, json_output_folder, columns) def download(ipart: int, url: str, - json_output_path: str, + json_output_folder: str, columns: Optional[List] = None, resp_format: str = 'arrow', compressed: bool = False) -> None: @@ -214,7 +214,7 @@ def download(ipart: int, Args: ipart (int): presigned url id url (str): presigned url - json_output_path (str): directory to save the ipart_th segment of dataframe + json_output_folder (str): directory to save the ipart_th segment of dataframe columns (list): schema to save to json resp_format (str): whether to use arrow or json when collect compressed (bool): if data is compressed before downloading. Need decompress if compressed=True. @@ -224,7 +224,7 @@ def download(ipart: int, if resp_format == 'json': data = resp.json() pd.DataFrame(data, columns=columns).to_json(os.path.join( - json_output_path, 'part_' + str(ipart) + '.jsonl'), + json_output_folder, 'part_' + str(ipart) + '.jsonl'), orient='records', lines=True) return @@ -242,7 +242,7 @@ def download(ipart: int, # Convert the PyArrow table into a pandas DataFrame df = table.to_pandas() - df.to_json(os.path.join(json_output_path, + df.to_json(os.path.join(json_output_folder, 'part_' + str(ipart) + '.jsonl'), orient='records', lines=True, @@ -256,7 +256,7 @@ def download_starargs(args: Tuple) -> None: def fetch_data(method: str, cursor: Optional[Cursor], sparkSession: Optional[SparkSession], start: int, end: int, order_by: str, tablename: str, columns_str: str, - json_output_path: str) -> None: + json_output_folder: str) -> None: """Fetches a specified range of rows from a given table to a json file. This function executes a SQL query to retrieve a range of rows, determined by 'start' and 'end' indexes, @@ -271,7 +271,7 @@ def fetch_data(method: str, cursor: Optional[Cursor], order_by (str): The column name to use for ordering the rows. tablename (str): The name of the table from which to fetch the data. columns_str (str): The string representation of the columns to select from the table. - json_output_path (str): The file path where the resulting JSON file will be saved. + json_output_folder (str): The file path where the resulting JSON file will be saved. Returns: None: The function doesn't return any value, but writes the result to a JSONL file. @@ -301,7 +301,7 @@ def fetch_data(method: str, cursor: Optional[Cursor], records = [r.asDict() for r in ans] # pyright: ignore pdf = pd.DataFrame.from_dict(records) - pdf.to_json(os.path.join(json_output_path, f'part_{start+1}_{end}.jsonl'), + pdf.to_json(os.path.join(json_output_folder, f'part_{start+1}_{end}.jsonl'), orient='records', lines=True) @@ -309,7 +309,7 @@ def fetch_data(method: str, cursor: Optional[Cursor], def fetch( method: str, tablename: str, - json_output_path: str, + json_output_folder: str, batch_size: int = 1 << 30, processes: int = 1, sparkSession: Optional[SparkSession] = None, @@ -320,7 +320,7 @@ def fetch( Args: method (str): dbconnect or dbsql tablename (str): catalog.scheme.tablename on UC - json_output_path (str): path to write the result json file to + json_output_folder (str): path to write the result json file to batch_size (int): number of rows that dbsql fetches each time to avoid OOM processes (int): max number of processes to use to parallelize the fetch sparkSession (pyspark.sql.sparksession): spark session @@ -358,7 +358,7 @@ def fetch( signed, _, _ = df.collect_cf('arrow') # pyright: ignore log.info(f'len(signed) = {len(signed)}') - args = get_args(signed, json_output_path, columns) + args = get_args(signed, json_output_folder, columns) # Stopping the SparkSession to avoid spilling connection state into the subprocesses. sparkSession.stop() @@ -371,7 +371,7 @@ def fetch( log.warning(f'batch {start}') end = min(start + batch_size, nrows) fetch_data(method, cursor, sparkSession, start, end, order_by, - tablename, columns_str, json_output_path) + tablename, columns_str, json_output_folder) if cursor is not None: cursor.close() @@ -381,21 +381,24 @@ def fetch_DT(args: Namespace) -> None: """Fetch UC Delta Table to local as jsonl.""" log.info(f'Start .... Convert delta to json') - obj = urllib.parse.urlparse(args.json_output_path) + obj = urllib.parse.urlparse(args.json_output_folder) if obj.scheme != '': raise ValueError( - f'Check the json_output_path and verify it is a local path!') + f'Check the json_output_folder and verify it is a local path!') - if os.path.exists(args.json_output_path): - if not os.path.isdir(args.json_output_path) or os.listdir( - args.json_output_path): + if os.path.exists(args.json_output_folder): + if not os.path.isdir(args.json_output_folder) or os.listdir( + args.json_output_folder): raise RuntimeError( - f'A file or a folder {args.json_output_path} already exists and is not empty. Remove it and retry!' + f'A file or a folder {args.json_output_folder} already exists and is not empty. Remove it and retry!' ) - os.makedirs(args.json_output_path, exist_ok=True) + os.makedirs(args.json_output_folder, exist_ok=True) - log.info(f'Directory {args.json_output_path} created.') + if not args.json_output_filename.endswith('.jsonl'): + raise ValueError('json_output_filename needs to be a jsonl file') + + log.info(f'Directory {args.json_output_folder} created.') method = 'dbsql' dbsql = None @@ -451,16 +454,16 @@ def fetch_DT(args: Namespace) -> None: 'Failed to create sql connection to db workspace. To use sql connect, you need to provide http_path and cluster_id!' ) from e - fetch(method, args.delta_table_name, args.json_output_path, args.batch_size, - args.processes, sparkSession, dbsql) + fetch(method, args.delta_table_name, args.json_output_folder, + args.batch_size, args.processes, sparkSession, dbsql) if dbsql is not None: dbsql.close() # combine downloaded jsonl into one big jsonl for IFT iterative_combine_jsons( - args.json_output_path, - os.path.join(args.json_output_path, 'combined.jsonl')) + args.json_output_folder, + os.path.join(args.json_output_folder, args.json_output_filename)) if __name__ == '__main__': @@ -471,7 +474,7 @@ def fetch_DT(args: Namespace) -> None: required=True, type=str, help='UC table ..
') - parser.add_argument('--json_output_path', + parser.add_argument('--json_output_folder', required=True, type=str, help='Local path to save the converted json') @@ -505,6 +508,12 @@ def fetch_DT(args: Namespace) -> None: help= 'Use serverless or not. Make sure the workspace is entitled with serverless' ) + parser.add_argument( + '--json_output_filename', + required=False, + type=str, + default='train-00000-of-00001.jsonl', + help='The combined final jsonl that combines all partitioned jsonl') args = parser.parse_args() from databricks.sdk import WorkspaceClient diff --git a/tests/a_scripts/data_prep/test_convert_delta_to_json.py b/tests/a_scripts/data_prep/test_convert_delta_to_json.py index 39bc5d8099..b366d8635a 100644 --- a/tests/a_scripts/data_prep/test_convert_delta_to_json.py +++ b/tests/a_scripts/data_prep/test_convert_delta_to_json.py @@ -27,7 +27,7 @@ def test_stream_delta_to_json(self, mock_workspace_client: Any, args = MagicMock() args.delta_table_name = 'test_table' - args.json_output_path = '/path/to/jsonl' + args.json_output_folder = '/path/to/jsonl' args.DATABRICKS_HOST = 'test_host' args.DATABRICKS_TOKEN = 'test_token' args.http_path = 'test_path' @@ -36,6 +36,7 @@ def test_stream_delta_to_json(self, mock_workspace_client: Any, args.cluster_id = '1234' args.debug = False args.use_serverless = False + args.json_output_filename = 'combined.jsonl' mock_cluster_get = MagicMock() mock_cluster_get.return_value = MagicMock( @@ -154,7 +155,7 @@ def test_dbconnect_called(self, mock_fetch: Any, mock_combine_jsons: Any, args = MagicMock() args.delta_table_name = 'test_table' - args.json_output_path = '/path/to/jsonl' + args.json_output_folder = '/path/to/jsonl' # Execute function with http_path=None (should use dbconnect) args.http_path = None args.cluster_id = '1234' @@ -192,7 +193,7 @@ def test_sqlconnect_called_dbr13(self, mock_fetch: Any, args = MagicMock() args.delta_table_name = 'test_table' - args.json_output_path = '/path/to/jsonl' + args.json_output_folder = '/path/to/jsonl' # Execute function with http_path=None (should use dbconnect) args.http_path = 'test_path' args.cluster_id = '1234' @@ -225,7 +226,7 @@ def test_sqlconnect_called_dbr14(self, mock_fetch: Any, args = MagicMock() args.delta_table_name = 'test_table' - args.json_output_path = '/path/to/jsonl' + args.json_output_folder = '/path/to/jsonl' # Execute function with http_path=None (should use dbconnect) args.http_path = 'test_path' args.cluster_id = '1234' @@ -258,7 +259,7 @@ def test_sqlconnect_called_https(self, mock_fetch: Any, args = MagicMock() args.delta_table_name = 'test_table' - args.json_output_path = '/path/to/jsonl' + args.json_output_folder = '/path/to/jsonl' # Execute function with http_path=None (should use dbconnect) args.http_path = 'test_path' args.cluster_id = '1234' @@ -288,7 +289,7 @@ def test_serverless(self, mock_fetch: Any, mock_combine_jsons: Any, args = MagicMock() args.delta_table_name = 'test_table' - args.json_output_path = '/path/to/jsonl' + args.json_output_folder = '/path/to/jsonl' # Execute function with http_path=None (should use dbconnect) args.http_path = 'test_path' args.cluster_id = '1234' From 936e3a1bd5f16fa3c2510c1af7753493635498be Mon Sep 17 00:00:00 2001 From: Mihir Patel Date: Thu, 11 Jan 2024 22:49:07 -0500 Subject: [PATCH 14/31] bump (#828) --- setup.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index 5444352cf7..2c4a05f396 100644 --- a/setup.py +++ b/setup.py @@ -96,13 +96,13 @@ extra_deps['gpu'] = [ 'flash-attn==1.0.9', - 'mosaicml-turbo==0.0.7', + 'mosaicml-turbo==0.0.8', # PyPI does not support direct dependencies, so we remove this line before uploading from PyPI 'xentropy-cuda-lib@git+https://github.com/HazyResearch/flash-attention.git@v1.0.9#subdirectory=csrc/xentropy', ] extra_deps['gpu-flash2'] = [ 'flash-attn==2.4.2', - 'mosaicml-turbo==0.0.7', + 'mosaicml-turbo==0.0.8', ] extra_deps['peft'] = [ From 6517a307ce3f484792e4f23dea05e342b549a6ef Mon Sep 17 00:00:00 2001 From: Mihir Patel Date: Fri, 12 Jan 2024 12:37:00 -0500 Subject: [PATCH 15/31] Always initialize dist (#864) * fix dev * lint * remove gpu --- tests/a_scripts/eval/test_eval.py | 1 - tests/fixtures/autouse.py | 6 +----- 2 files changed, 1 insertion(+), 6 deletions(-) diff --git a/tests/a_scripts/eval/test_eval.py b/tests/a_scripts/eval/test_eval.py index e8d86903dc..c9dfb88732 100644 --- a/tests/a_scripts/eval/test_eval.py +++ b/tests/a_scripts/eval/test_eval.py @@ -71,7 +71,6 @@ def test_icl_eval(eval_cfg: Union[om.ListConfig, om.DictConfig], capfd: Any, assert expected_results in out -@pytest.mark.gpu def test_loader_eval(capfd: Any, mock_saved_model_path: Any, tmp_path: pathlib.Path): diff --git a/tests/fixtures/autouse.py b/tests/fixtures/autouse.py index 75caa6c941..ccbe1b69f7 100644 --- a/tests/fixtures/autouse.py +++ b/tests/fixtures/autouse.py @@ -17,12 +17,8 @@ @pytest.fixture(autouse=True) def initialize_dist(request: pytest.FixtureRequest): """Initialize the default PyTorch distributed process group for tests.""" - # should we just always initialize dist like in train.py? - _default = pytest.mark.world_size(1).mark - world_size = request.node.get_closest_marker('world_size', _default).args[0] gpu = request.node.get_closest_marker('gpu') - if world_size > 1: - dist.initialize_dist(get_device('gpu' if gpu is not None else 'cpu')) + dist.initialize_dist(get_device('gpu' if gpu is not None else 'cpu')) @pytest.fixture(autouse=True) From d05c0992e4ee86b3e9b0d120c39214e151a26f0d Mon Sep 17 00:00:00 2001 From: Milo Cress Date: Fri, 12 Jan 2024 14:19:15 -0500 Subject: [PATCH 16/31] Logs upload URI (#850) * fix style etc. * fix * fix fix * fix fix fix * fix fix fix fix * removed unused dummy func * deleted tests to make the tests pass * tried adding back some tests to see if it triggers the issue * add test_hf_checkpointer.py but remove references to MPT * fix? * fixed test cases overlapping in strange side-effecty ways --- llmfoundry/callbacks/hf_checkpointer.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py index 8b139c0c25..b50d81d09e 100644 --- a/llmfoundry/callbacks/hf_checkpointer.py +++ b/llmfoundry/callbacks/hf_checkpointer.py @@ -241,11 +241,16 @@ def _save_checkpoint(self, state: State, logger: Logger): ) if self.remote_ud is not None: - log.info(f'Uploading HuggingFace formatted checkpoint') for filename in os.listdir(temp_save_dir): + remote_file_name = os.path.join(save_dir, filename) + remote_file_uri = self.remote_ud.remote_backend.get_uri( + remote_file_name) + log.info( + f'Uploading HuggingFace formatted checkpoint to {remote_file_uri}' + ) self.remote_ud.upload_file( state=state, - remote_file_name=os.path.join(save_dir, filename), + remote_file_name=remote_file_name, file_path=Path(os.path.join(temp_save_dir, filename)), overwrite=self.overwrite, From b69318e81b2addef170325edac9b627635033210 Mon Sep 17 00:00:00 2001 From: Nancy Hung Date: Fri, 12 Jan 2024 23:06:28 -0800 Subject: [PATCH 17/31] Delta to JSONL conversion script cleanup and bug fix (#868) * Small test change * small cleanups * lint and precommit * lint and precommit * comments * another one * pr suggestion and use input param not args --- scripts/data_prep/convert_delta_to_json.py | 110 +++++++++++++-------- 1 file changed, 70 insertions(+), 40 deletions(-) diff --git a/scripts/data_prep/convert_delta_to_json.py b/scripts/data_prep/convert_delta_to_json.py index 029ce7f5c3..326b8e912f 100644 --- a/scripts/data_prep/convert_delta_to_json.py +++ b/scripts/data_prep/convert_delta_to_json.py @@ -33,8 +33,8 @@ from pyspark.sql.dataframe import DataFrame as SparkDataFrame from pyspark.sql.types import Row -MINIMUM_DB_CONNECT_DBR_VERSION = '14.1.0' -MINIMUM_SQ_CONNECT_DBR_VERSION = '12.2.0' +MINIMUM_DB_CONNECT_DBR_VERSION = '14.1' +MINIMUM_SQ_CONNECT_DBR_VERSION = '12.2' log = logging.getLogger(__name__) @@ -377,64 +377,61 @@ def fetch( cursor.close() -def fetch_DT(args: Namespace) -> None: - """Fetch UC Delta Table to local as jsonl.""" - log.info(f'Start .... Convert delta to json') - - obj = urllib.parse.urlparse(args.json_output_folder) - if obj.scheme != '': - raise ValueError( - f'Check the json_output_folder and verify it is a local path!') - - if os.path.exists(args.json_output_folder): - if not os.path.isdir(args.json_output_folder) or os.listdir( - args.json_output_folder): - raise RuntimeError( - f'A file or a folder {args.json_output_folder} already exists and is not empty. Remove it and retry!' - ) - - os.makedirs(args.json_output_folder, exist_ok=True) - - if not args.json_output_filename.endswith('.jsonl'): - raise ValueError('json_output_filename needs to be a jsonl file') - - log.info(f'Directory {args.json_output_folder} created.') +def validate_and_get_cluster_info(cluster_id: str, + databricks_host: str, + databricks_token: str, + http_path: Optional[str], + use_serverless: bool = False) -> tuple: + """Validate and get cluster info for running the Delta to JSONL conversion. + Args: + cluster_id (str): cluster id to validate and fetch additional info for + databricks_host (str): databricks host name + databricks_token (str): databricks auth token + http_path (Optional[str]): http path to use for sql connect + use_serverless (bool): whether to use serverless or not + """ method = 'dbsql' dbsql = None sparkSession = None - if args.use_serverless: + if use_serverless: method = 'dbconnect' else: w = WorkspaceClient() - res = w.clusters.get(cluster_id=args.cluster_id) - runtime_version = res.spark_version.split('-scala')[0].replace( - 'x-snapshot', '0').replace('x', '0') + res = w.clusters.get(cluster_id=cluster_id) + if res is None: + raise ValueError( + f'Cluster id {cluster_id} does not exist. Check cluster id and try again!' + ) + stripped_runtime = re.sub( + r'[a-zA-Z]', '', + res.spark_version.split('-scala')[0].replace('x-snapshot', '')) + runtime_version = re.sub(r'.-+$', '', stripped_runtime) if version.parse(runtime_version) < version.parse( MINIMUM_SQ_CONNECT_DBR_VERSION): raise ValueError( f'The minium DBR version required is {MINIMUM_SQ_CONNECT_DBR_VERSION} but got {version.parse(runtime_version)}' ) - if args.http_path is None and version.parse( + if http_path is None and version.parse( runtime_version) >= version.parse( MINIMUM_DB_CONNECT_DBR_VERSION): method = 'dbconnect' if method == 'dbconnect': try: - if args.use_serverless: + if use_serverless: session_id = str(uuid4()) sparkSession = DatabricksSession.builder.host( - args.DATABRICKS_HOST).token(args.DATABRICKS_TOKEN).header( + databricks_host).token(databricks_token).header( 'x-databricks-session-id', session_id).getOrCreate() else: sparkSession = DatabricksSession.builder.remote( - host=args.DATABRICKS_HOST, - token=args.DATABRICKS_TOKEN, - cluster_id=args.cluster_id).getOrCreate() + host=databricks_host, + token=databricks_token, + cluster_id=cluster_id).getOrCreate() except Exception as e: raise RuntimeError( @@ -444,15 +441,47 @@ def fetch_DT(args: Namespace) -> None: try: dbsql = sql.connect( server_hostname=re.compile(r'^https?://').sub( - '', args.DATABRICKS_HOST).strip( + '', databricks_host).strip( ), # sqlconnect hangs if hostname starts with https - http_path=args.http_path, - access_token=args.DATABRICKS_TOKEN, + http_path=http_path, + access_token=databricks_token, ) except Exception as e: raise RuntimeError( 'Failed to create sql connection to db workspace. To use sql connect, you need to provide http_path and cluster_id!' ) from e + return method, dbsql, sparkSession + + +def fetch_DT(args: Namespace) -> None: + """Fetch UC Delta Table to local as jsonl.""" + log.info(f'Start .... Convert delta to json') + + obj = urllib.parse.urlparse(args.json_output_folder) + if obj.scheme != '': + raise ValueError( + f'Check the json_output_folder and verify it is a local path!') + + if os.path.exists(args.json_output_folder): + if not os.path.isdir(args.json_output_folder) or os.listdir( + args.json_output_folder): + raise RuntimeError( + f'A file or a folder {args.json_output_folder} already exists and is not empty. Remove it and retry!' + ) + + os.makedirs(args.json_output_folder, exist_ok=True) + + if not args.json_output_filename.endswith('.jsonl'): + raise ValueError('json_output_filename needs to be a jsonl file') + + log.info(f'Directory {args.json_output_folder} created.') + + method, dbsql, sparkSession = validate_and_get_cluster_info( + cluster_id=args.cluster_id, + databricks_host=args.DATABRICKS_HOST, + databricks_token=args.DATABRICKS_TOKEN, + http_path=args.http_path, + use_serverless=args.use_serverless) fetch(method, args.delta_table_name, args.json_output_folder, args.batch_size, args.processes, sparkSession, dbsql) @@ -494,9 +523,8 @@ def fetch_DT(args: Namespace) -> None: help='number of processes allowed to use') parser.add_argument( '--cluster_id', - required=True, + required=False, type=str, - default=None, help= 'cluster id has runtime newer than 14.1.0 and access mode of either assigned or shared can use databricks-connect.' ) @@ -513,7 +541,9 @@ def fetch_DT(args: Namespace) -> None: required=False, type=str, default='train-00000-of-00001.jsonl', - help='The combined final jsonl that combines all partitioned jsonl') + help= + 'The name of the combined final jsonl that combines all partitioned jsonl' + ) args = parser.parse_args() from databricks.sdk import WorkspaceClient From b40b80519203ffb7ecdd638d8e0af52019d4689d Mon Sep 17 00:00:00 2001 From: Jerry Chen Date: Sun, 14 Jan 2024 11:17:14 -0800 Subject: [PATCH 18/31] fix mock (#872) --- tests/a_scripts/inference/test_convert_composer_to_hf.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/a_scripts/inference/test_convert_composer_to_hf.py b/tests/a_scripts/inference/test_convert_composer_to_hf.py index 28fb9219f8..deed181475 100644 --- a/tests/a_scripts/inference/test_convert_composer_to_hf.py +++ b/tests/a_scripts/inference/test_convert_composer_to_hf.py @@ -300,6 +300,8 @@ def test_huggingface_conversion_callback_interval( mlflow_logger_mock.save_model = MagicMock() mlflow_logger_mock.register_model = MagicMock() mlflow_logger_mock.model_registry_prefix = '' + mlflow_logger_mock._experiment_id = 'mlflow-experiment-id' + mlflow_logger_mock._run_id = 'mlflow-run-id' trainer = Trainer( model=original_model, device='gpu', @@ -534,6 +536,8 @@ def test_huggingface_conversion_callback( mlflow_logger_mock.save_model = MagicMock() mlflow_logger_mock.register_model = MagicMock() mlflow_logger_mock.model_registry_prefix = '' + mlflow_logger_mock._experiment_id = 'mlflow-experiment-id' + mlflow_logger_mock._run_id = 'mlflow-run-id' trainer = Trainer( model=original_model, device='gpu', From f43d1cfb1ef8f38ca90fee68b0643f45d6d5b2da Mon Sep 17 00:00:00 2001 From: Nancy Hung Date: Tue, 16 Jan 2024 17:45:43 -0800 Subject: [PATCH 19/31] fix regex (#877) --- scripts/data_prep/convert_delta_to_json.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/data_prep/convert_delta_to_json.py b/scripts/data_prep/convert_delta_to_json.py index 326b8e912f..9d73cb0e9f 100644 --- a/scripts/data_prep/convert_delta_to_json.py +++ b/scripts/data_prep/convert_delta_to_json.py @@ -407,7 +407,7 @@ def validate_and_get_cluster_info(cluster_id: str, stripped_runtime = re.sub( r'[a-zA-Z]', '', res.spark_version.split('-scala')[0].replace('x-snapshot', '')) - runtime_version = re.sub(r'.-+$', '', stripped_runtime) + runtime_version = re.sub(r'[.-]*$', '', stripped_runtime) if version.parse(runtime_version) < version.parse( MINIMUM_SQ_CONNECT_DBR_VERSION): raise ValueError( From 6dcc0d84b9cc7e5ba3b3f8148cf8459e298e2ba1 Mon Sep 17 00:00:00 2001 From: Shashank Rajput <144760128+ShashankMosaicML@users.noreply.github.com> Date: Wed, 17 Jan 2024 18:24:04 -0800 Subject: [PATCH 20/31] Precompute flash attention padding info (#880) * .. * .. * .. * .. * .. * .. * .. * .. * .. * .. * .. * .. * .. * .. * .. * .. * Update llmfoundry/models/mpt/modeling_mpt.py Co-authored-by: Vitaliy Chiley <6439018+vchiley@users.noreply.github.com> * dummy data * undoing last commit * .. * .. * Update llmfoundry/models/mpt/modeling_mpt.py Co-authored-by: Vitaliy Chiley <6439018+vchiley@users.noreply.github.com> * .. * .. --------- Co-authored-by: Vitaliy Chiley <6439018+vchiley@users.noreply.github.com> --- llmfoundry/models/layers/attention.py | 39 ++-- llmfoundry/models/layers/blocks.py | 4 +- llmfoundry/models/mpt/modeling_mpt.py | 89 +++++++-- tests/models/layers/test_flash_attn.py | 180 ++++++++++-------- .../models/layers/test_flash_triton_torch.py | 32 +++- tests/models/test_rope_dail_vs_hf.py | 13 +- 6 files changed, 232 insertions(+), 125 deletions(-) diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py index 0fb6c0a042..fecd79553f 100644 --- a/llmfoundry/models/layers/attention.py +++ b/llmfoundry/models/layers/attention.py @@ -228,12 +228,17 @@ def flash_attn_fn( training: bool = False, needs_weights: bool = False, multiquery: bool = False, - attention_mask_in_length: Optional[torch.Tensor] = None, should_repeat_kv_for_gqa: Optional[bool] = True, sliding_window_size: int = -1, alibi_slopes: Optional[torch.Tensor] = None, + flash_attn_padding_info: Optional[dict[str, torch.Tensor]] = None, ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor, torch.Tensor]]]: + if key_padding_mask is not None: + raise ValueError('key_padding_mask should be None for flash attn.') + del key_padding_mask + if flash_attn_padding_info is None: + raise ValueError('flash_attn_padding_info is required for flash attn.') try: from flash_attn import bert_padding, flash_attn_interface # type: ignore # yapf: disable # isort: skip except: @@ -267,25 +272,24 @@ def flash_attn_fn( batch_size, seqlen = query.shape[:2] - if attention_mask_in_length is None: - if key_padding_mask is None: - key_padding_mask = torch.ones_like(key[:, :, 0], dtype=torch.bool) - query_padding_mask = key_padding_mask[:, -query.size(1):] - unpadding_function = bert_padding.unpad_input - else: - key_padding_mask = attention_mask_in_length - query_padding_mask = attention_mask_in_length - unpadding_function = bert_padding.unpad_input_for_concatenated_sequences + indices_q = flash_attn_padding_info['indices_q'] + indices_k = flash_attn_padding_info['indices_k'] + indices_v = flash_attn_padding_info['indices_v'] + cu_seqlens_q = flash_attn_padding_info['cu_seqlens_q'] + cu_seqlens_k = flash_attn_padding_info['cu_seqlens_k'] + max_seqlen_q = flash_attn_padding_info['max_seqlen_q'] + max_seqlen_k = flash_attn_padding_info['max_seqlen_k'] - query_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpadding_function( - query, query_padding_mask) + query_unpad = bert_padding.index_first_axis( + rearrange(query, 'b s ... -> (b s) ...'), indices_q) query_unpad = rearrange(query_unpad, 'nnz (h d) -> nnz h d', h=n_heads) - key_unpad, _, cu_seqlens_k, max_seqlen_k = unpadding_function( - key, key_padding_mask) + key_unpad = bert_padding.index_first_axis( + rearrange(key, 'b s ... -> (b s) ...'), indices_k) key_unpad = rearrange(key_unpad, 'nnz (h d) -> nnz h d', h=kv_n_heads) - value_unpad, _, _, _ = unpadding_function(value, key_padding_mask) + value_unpad = bert_padding.index_first_axis( + rearrange(value, 'b s ... -> (b s) ...'), indices_v) value_unpad = rearrange(value_unpad, 'nnz (h d) -> nnz h d', h=kv_n_heads) if (kv_n_heads < n_heads) and (not is_flash_v2_installed()) and ( @@ -599,8 +603,8 @@ def forward( rotary_emb_w_meta_info: Optional[dict] = None, is_causal: bool = True, needs_weights: bool = False, - attention_mask_in_length: Optional[torch.Tensor] = None, alibi_slopes: Optional[torch.Tensor] = None, + flash_attn_padding_info: Optional[dict[str, torch.Tensor]] = None, ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[ torch.Tensor, torch.Tensor]]]: qkv = self.Wqkv(x) @@ -666,11 +670,12 @@ def forward( extra_attn_kwargs = {} if self.attn_impl == 'flash': + key_padding_mask = None extra_attn_kwargs = { - 'attention_mask_in_length': attention_mask_in_length, 'should_repeat_kv_for_gqa': not is_flash_v2_installed(), 'sliding_window_size': self.sliding_window_size, 'alibi_slopes': alibi_slopes, + 'flash_attn_padding_info': flash_attn_padding_info, } context, attn_weights, past_key_value = self.attn_fn( diff --git a/llmfoundry/models/layers/blocks.py b/llmfoundry/models/layers/blocks.py index e5032998dc..036a4e7cd2 100644 --- a/llmfoundry/models/layers/blocks.py +++ b/llmfoundry/models/layers/blocks.py @@ -122,8 +122,8 @@ def forward( attention_mask: Optional[torch.ByteTensor] = None, is_causal: bool = True, output_attentions: bool = False, - attention_mask_in_length: Optional[torch.Tensor] = None, alibi_slopes: Optional[torch.Tensor] = None, + flash_attn_padding_info: Optional[dict[str, torch.Tensor]] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[ torch.Tensor, torch.Tensor]]]: a = self.norm_1(x) @@ -135,8 +135,8 @@ def forward( attention_mask=attention_mask, is_causal=is_causal, needs_weights=output_attentions, - attention_mask_in_length=attention_mask_in_length, alibi_slopes=alibi_slopes, + flash_attn_padding_info=flash_attn_padding_info, ) x = x + self.resid_attn_dropout(b) m = x diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 8b14c72f62..f49b1b88f8 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -24,15 +24,23 @@ from composer.models import HuggingFaceModel from composer.utils import dist -from llmfoundry.models.layers.attention import is_flash_v2_installed +from llmfoundry.models.layers.attention import (is_flash_v1_installed, + is_flash_v2_installed) if is_flash_v2_installed(): try: # This try...except is needed because transformers requires it despite the 'if' statement above + from flash_attn import bert_padding from flash_attn.layers.rotary import \ RotaryEmbedding as DAILRotaryEmbedding except Exception as e: raise e +if is_flash_v1_installed(): + try: # This try...except is needed because transformers requires it despite the 'if' statement above + from flash_attn import bert_padding + except Exception as e: + raise e + from omegaconf import DictConfig from omegaconf import OmegaConf as om from transformers import PreTrainedModel, PreTrainedTokenizerBase @@ -216,6 +224,44 @@ def gen_attention_mask_in_length(sequence_id: Union[None, torch.Tensor], S: int, return attention_mask_in_length +def gen_flash_attn_padding_info( + bsz: int, + S: int, + past_key_len: int, + device: torch.device, + attention_mask_in_length: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None): + flash_attn_padding_info = {} + if attention_mask_in_length is None: + key_padding_mask = attention_mask + if key_padding_mask is None: + key_padding_mask = torch.ones((bsz, past_key_len + S), + dtype=torch.bool, + device=device) + query_padding_mask = key_padding_mask[:, -S:] + unpadding_function = bert_padding.unpad_input + else: + key_padding_mask = attention_mask_in_length + query_padding_mask = attention_mask_in_length + unpadding_function = bert_padding.unpad_input_for_concatenated_sequences + + _, indices_q, cu_seqlens_q, max_seqlen_q = unpadding_function( + torch.empty(bsz, S, 1, device=device), query_padding_mask) + _, indices_k, cu_seqlens_k, max_seqlen_k = unpadding_function( + torch.empty(bsz, past_key_len + S, 1, device=device), key_padding_mask) + _, indices_v, _, _ = unpadding_function( + torch.empty(bsz, past_key_len + S, 1, device=device), key_padding_mask) + + flash_attn_padding_info['indices_q'] = indices_q + flash_attn_padding_info['indices_k'] = indices_k + flash_attn_padding_info['indices_v'] = indices_v + flash_attn_padding_info['cu_seqlens_q'] = cu_seqlens_q + flash_attn_padding_info['cu_seqlens_k'] = cu_seqlens_k + flash_attn_padding_info['max_seqlen_q'] = max_seqlen_q + flash_attn_padding_info['max_seqlen_k'] = max_seqlen_k + return flash_attn_padding_info + + def apply_sequence_id(attn_bias: torch.Tensor, sequence_id: torch.LongTensor, max_seq_len: int) -> torch.Tensor: seq_len = sequence_id.shape[-1] @@ -515,10 +561,12 @@ def forward( raise ValueError( 'You cannot specify both input_ids and inputs_embeds.') elif input_ids is not None: + bsz = input_ids.size(0) S = input_ids.size(1) x = self.wte(input_ids) input_device = input_ids.device elif inputs_embeds is not None: + bsz = inputs_embeds.size(0) S = inputs_embeds.size(1) x = inputs_embeds input_device = inputs_embeds.device @@ -530,22 +578,23 @@ def forward( ), f'Cannot forward input with seq_len={S}, this model only supports seq_len<={self.config.max_seq_len}' rotary_emb_w_meta_info = None - if self.learned_pos_emb or self.rope: - past_position = 0 - if past_key_values is not None: - if len(past_key_values) != self.config.n_layers: - raise ValueError( - f'past_key_values must provide a past_key_value for each attention ' - + - f'layer in the network ({len(past_key_values)=}; {self.config.n_layers=}).' - ) - # For attn_impl: triton and flash the past key tensor spec is (batch, seq, dim). - # For attn_impl: torch the past key tensor spec is (batch, heads, head_dim, seq). - # Here we shift position embedding using the `seq` dim of the past key - past_position = past_key_values[0][0].size(1) - if self.attn_impl == 'torch': - past_position = past_key_values[0][0].size(3) + past_position = 0 + if past_key_values is not None: + if len(past_key_values) != self.config.n_layers: + raise ValueError( + f'past_key_values must provide a past_key_value for each attention ' + + + f'layer in the network ({len(past_key_values)=}; {self.config.n_layers=}).' + ) + # For attn_impl: triton and flash the past key tensor spec is (batch, seq, dim). + # For attn_impl: torch the past key tensor spec is (batch, heads, head_dim, seq). + # Here we shift position embedding using the `seq` dim of the past key + past_position = past_key_values[0][0].size(1) + if self.attn_impl == 'torch': + past_position = past_key_values[0][0].size(3) + + if self.learned_pos_emb or self.rope: if self.learned_pos_emb and (S + past_position > self.config.max_seq_len): raise ValueError( @@ -623,6 +672,12 @@ def forward( all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None + flash_attn_padding_info = {} + if self.attn_impl == 'flash': + flash_attn_padding_info = gen_flash_attn_padding_info( + bsz, S, past_position, x.device, attention_mask_in_length, + attention_mask) + for b_idx, block in enumerate(self.blocks): if output_hidden_states: assert all_hidden_states is not None # pyright @@ -637,8 +692,8 @@ def forward( attention_mask=attention_mask, is_causal=self.is_causal, output_attentions=bool(output_attentions), - attention_mask_in_length=attention_mask_in_length, alibi_slopes=alibi_slopes, + flash_attn_padding_info=flash_attn_padding_info, ) if presents is not None: presents += (present,) diff --git a/tests/models/layers/test_flash_attn.py b/tests/models/layers/test_flash_attn.py index 3e1ec37b2e..9471cdac6a 100644 --- a/tests/models/layers/test_flash_attn.py +++ b/tests/models/layers/test_flash_attn.py @@ -12,6 +12,7 @@ flash_attn_fn, gen_slopes, is_flash_v2_installed, triton_flash_attn_fn) +from llmfoundry.models.mpt.modeling_mpt import gen_flash_attn_padding_info @pytest.mark.gpu @@ -35,22 +36,24 @@ def test_gqa_kv_repetition(kv_n_heads: int): kv_n_heads * d).to(torch.bfloat16).cuda() value_1.requires_grad = True - output_1, _, _ = flash_attn_fn(query=query_1, - key=key_1, - value=value_1, - n_heads=n_heads, - kv_n_heads=kv_n_heads, - past_key_value=None, - softmax_scale=1 / math.sqrt(d), - attn_bias=None, - key_padding_mask=None, - is_causal=True, - dropout_p=0.0, - training=False, - needs_weights=False, - multiquery=False, - attention_mask_in_length=None, - should_repeat_kv_for_gqa=True) + output_1, _, _ = flash_attn_fn( + query=query_1, + key=key_1, + value=value_1, + n_heads=n_heads, + kv_n_heads=kv_n_heads, + past_key_value=None, + softmax_scale=1 / math.sqrt(d), + attn_bias=None, + key_padding_mask=None, + is_causal=True, + dropout_p=0.0, + training=False, + needs_weights=False, + multiquery=False, + flash_attn_padding_info=gen_flash_attn_padding_info( + bsz, seqlen_1, 0, query_1.device, None, None), + should_repeat_kv_for_gqa=True) output_1.sum().backward() @@ -61,22 +64,24 @@ def test_gqa_kv_repetition(kv_n_heads: int): value_2 = value_1.detach().clone() value_2.requires_grad = True - output_2, _, _ = flash_attn_fn(query=query_2, - key=key_2, - value=value_2, - n_heads=n_heads, - kv_n_heads=kv_n_heads, - past_key_value=None, - softmax_scale=1 / math.sqrt(d), - attn_bias=None, - key_padding_mask=None, - is_causal=True, - dropout_p=0.0, - training=False, - needs_weights=False, - multiquery=False, - attention_mask_in_length=None, - should_repeat_kv_for_gqa=False) + output_2, _, _ = flash_attn_fn( + query=query_2, + key=key_2, + value=value_2, + n_heads=n_heads, + kv_n_heads=kv_n_heads, + past_key_value=None, + softmax_scale=1 / math.sqrt(d), + attn_bias=None, + key_padding_mask=None, + is_causal=True, + dropout_p=0.0, + training=False, + needs_weights=False, + multiquery=False, + flash_attn_padding_info=gen_flash_attn_padding_info( + bsz, seqlen_1, 0, query_2.device, None, None), + should_repeat_kv_for_gqa=False) output_2.sum().backward() assert torch.allclose(output_1, output_2) @@ -114,6 +119,9 @@ def test_seq_id_masking_FA_v2(): [3, 2, 1, 0, 0, 0]]).to(torch.int64).cuda() + flash_attn_padding_info_1 = gen_flash_attn_padding_info( + bsz, seqlen_1, 0, query_1.device, attention_mask_in_length_1, None) + output_1, _, _ = flash_attn_fn( query=query_1, key=key_1, @@ -129,7 +137,7 @@ def test_seq_id_masking_FA_v2(): training=False, needs_weights=False, multiquery=False, - attention_mask_in_length=attention_mask_in_length_1) + flash_attn_padding_info=flash_attn_padding_info_1) output_1.sum().backward() @@ -141,21 +149,25 @@ def test_seq_id_masking_FA_v2(): value_2 = value_1.detach().clone()[:, seq_range[0]:seq_range[1], :] value_2.requires_grad = True - output_2, _, _ = flash_attn_fn(query=query_2, - key=key_2, - value=value_2, - n_heads=n_heads, - kv_n_heads=kv_n_heads, - past_key_value=None, - softmax_scale=1 / math.sqrt(d), - attn_bias=None, - key_padding_mask=None, - is_causal=True, - dropout_p=0.0, - training=False, - needs_weights=False, - multiquery=False, - attention_mask_in_length=None) + flash_attn_padding_info_2 = gen_flash_attn_padding_info( + bsz, seq_range[1] - seq_range[0], 0, query_2.device, None, None) + + output_2, _, _ = flash_attn_fn( + query=query_2, + key=key_2, + value=value_2, + n_heads=n_heads, + kv_n_heads=kv_n_heads, + past_key_value=None, + softmax_scale=1 / math.sqrt(d), + attn_bias=None, + key_padding_mask=None, + is_causal=True, + dropout_p=0.0, + training=False, + needs_weights=False, + multiquery=False, + flash_attn_padding_info=flash_attn_padding_info_2) output_2.sum().backward() assert torch.allclose(output_1[:, seq_range[0]:seq_range[1], :], @@ -196,23 +208,25 @@ def test_sliding_window(sliding_window_size: int): device=device) value_1.requires_grad = True - output_1, _, _ = flash_attn_fn(query=query_1, - key=key_1, - value=value_1, - n_heads=n_heads, - kv_n_heads=n_heads, - past_key_value=None, - softmax_scale=1 / math.sqrt(d), - attn_bias=None, - key_padding_mask=None, - is_causal=True, - dropout_p=0.0, - training=False, - needs_weights=False, - multiquery=False, - attention_mask_in_length=None, - should_repeat_kv_for_gqa=True, - sliding_window_size=sliding_window_size) + output_1, _, _ = flash_attn_fn( + query=query_1, + key=key_1, + value=value_1, + n_heads=n_heads, + kv_n_heads=n_heads, + past_key_value=None, + softmax_scale=1 / math.sqrt(d), + attn_bias=None, + key_padding_mask=None, + is_causal=True, + dropout_p=0.0, + training=False, + needs_weights=False, + multiquery=False, + flash_attn_padding_info=gen_flash_attn_padding_info( + bsz, seqlen_1, 0, query_1.device, None, None), + should_repeat_kv_for_gqa=True, + sliding_window_size=sliding_window_size) output_1.sum().backward() @@ -284,23 +298,25 @@ def test_alibi_bias(n_heads: int): alibi_bias_max=8, device=torch.device(device), return_1d=True) - output_1, _, _ = flash_attn_fn(query=query_1, - key=key_1, - value=value_1, - n_heads=n_heads, - kv_n_heads=n_heads, - past_key_value=None, - softmax_scale=1 / math.sqrt(d), - attn_bias=None, - key_padding_mask=None, - is_causal=True, - dropout_p=0.0, - training=False, - needs_weights=False, - multiquery=False, - attention_mask_in_length=None, - should_repeat_kv_for_gqa=True, - alibi_slopes=alibi_slopes_1) + output_1, _, _ = flash_attn_fn( + query=query_1, + key=key_1, + value=value_1, + n_heads=n_heads, + kv_n_heads=n_heads, + past_key_value=None, + softmax_scale=1 / math.sqrt(d), + attn_bias=None, + key_padding_mask=None, + is_causal=True, + dropout_p=0.0, + training=False, + needs_weights=False, + multiquery=False, + flash_attn_padding_info=gen_flash_attn_padding_info( + bsz, seqlen_1, 0, query_1.device, None, None), + should_repeat_kv_for_gqa=True, + alibi_slopes=alibi_slopes_1) output_1.sum().backward() diff --git a/tests/models/layers/test_flash_triton_torch.py b/tests/models/layers/test_flash_triton_torch.py index 4ca5c7b668..2f992cd92f 100644 --- a/tests/models/layers/test_flash_triton_torch.py +++ b/tests/models/layers/test_flash_triton_torch.py @@ -10,6 +10,7 @@ is_flash_v2_installed) from llmfoundry.models.mpt.modeling_mpt import (apply_sequence_id, gen_attention_mask_in_length, + gen_flash_attn_padding_info, gen_rotary_embedding) @@ -164,6 +165,13 @@ def gen_bias(attn_impl: str): attn_uses_sequence_id=attn_uses_sequence_id, attn_impl=attn_impl_0, attention_mask=attention_mask) + + flash_attn_padding_info_0 = {} + if attn_impl_0 == 'flash': + flash_attn_padding_info_0 = gen_flash_attn_padding_info( + n, s, 0, torch.device(device), attention_mask_in_length_0, + attention_mask) + attention_mask_in_length_1 = gen_attention_mask_in_length( sequence_id=sequence_id, S=s, @@ -171,6 +179,12 @@ def gen_bias(attn_impl: str): attn_impl=attn_impl_1, attention_mask=attention_mask) + flash_attn_padding_info_1 = {} + if attn_impl_1 == 'flash': + flash_attn_padding_info_1 = gen_flash_attn_padding_info( + n, s, 0, torch.device(device), attention_mask_in_length_1, + attention_mask) + x0 = torch.randn(n, s, f).to(device) x1 = x0.clone().detach() x0.requires_grad = True @@ -216,7 +230,7 @@ def gen_bias(attn_impl: str): attention_mask=attention_mask, rotary_emb_w_meta_info=rotary_emb_w_meta_info, is_causal=True, - attention_mask_in_length=attention_mask_in_length_0, + flash_attn_padding_info=flash_attn_padding_info_0, alibi_slopes=alibi_slopes_0) attn_bias_1 = gen_bias(attn_impl_1) alibi_slopes_1 = None @@ -231,7 +245,7 @@ def gen_bias(attn_impl: str): attention_mask=attention_mask, rotary_emb_w_meta_info=rotary_emb_w_meta_info, is_causal=True, - attention_mask_in_length=attention_mask_in_length_1, + flash_attn_padding_info=flash_attn_padding_info_1, alibi_slopes=alibi_slopes_1) y0 *= attention_mask.unsqueeze(-1) y1 *= attention_mask.unsqueeze(-1) @@ -313,11 +327,16 @@ def gen_tca_mask(): x1.requires_grad = True with torch.autocast(x0.device.type): + flash_attn_padding_info = None + if attn_impl == 'flash': + flash_attn_padding_info = gen_flash_attn_padding_info( + n, s, 0, torch.device(device), None, attention_mask) y0, _, _ = mmhsa(x0, past_key_value=None, attn_bias=None, attention_mask=attention_mask, - is_causal=True) + is_causal=True, + flash_attn_padding_info=flash_attn_padding_info) y1, _ = tmhsa(x1, x1, x1, @@ -387,11 +406,16 @@ def test_grouped_attention_heads(attn_impl: str, x0.requires_grad = True with torch.autocast(x0.device.type): + flash_attn_padding_info = None + if attn_impl == 'flash': + flash_attn_padding_info = gen_flash_attn_padding_info( + n, s, 0, torch.device(device), None, attention_mask) y0, _, _ = mmhsa(x0, past_key_value=None, attn_bias=None, attention_mask=attention_mask, - is_causal=True) + is_causal=True, + flash_attn_padding_info=flash_attn_padding_info) y0 *= attention_mask.unsqueeze(-1) loss0 = y0.sum() diff --git a/tests/models/test_rope_dail_vs_hf.py b/tests/models/test_rope_dail_vs_hf.py index 70a00470f9..33c3d3c052 100644 --- a/tests/models/test_rope_dail_vs_hf.py +++ b/tests/models/test_rope_dail_vs_hf.py @@ -7,7 +7,8 @@ from omegaconf import OmegaConf as om from llmfoundry.models.layers.attention import is_flash_v2_installed -from llmfoundry.models.mpt.modeling_mpt import gen_rotary_embedding +from llmfoundry.models.mpt.modeling_mpt import (gen_flash_attn_padding_info, + gen_rotary_embedding) @pytest.mark.gpu @@ -104,14 +105,20 @@ def test_rope_dail_vs_hf(attn_type: str, seq_len: int, device: str = 'cuda'): attn_bias=None, attention_mask=attention_mask, rotary_emb_w_meta_info=dail_rope_w_meta_info, - is_causal=True) + is_causal=True, + flash_attn_padding_info=gen_flash_attn_padding_info( + batch_size, seq_len, 0, torch.device(device), None, + attention_mask)) y1, _, _ = attn1(x1, past_key_value=None, attn_bias=None, attention_mask=attention_mask, rotary_emb_w_meta_info=hf_rope_w_meta_info, - is_causal=True) + is_causal=True, + flash_attn_padding_info=gen_flash_attn_padding_info( + batch_size, seq_len, 0, torch.device(device), None, + attention_mask)) y0 *= attention_mask.unsqueeze(-1) y1 *= attention_mask.unsqueeze(-1) From 2e4f4b2f6659fd95d055a865e4c09247ed99566f Mon Sep 17 00:00:00 2001 From: Mihir Patel Date: Wed, 17 Jan 2024 21:47:47 -0500 Subject: [PATCH 21/31] add missing import (#882) --- llmfoundry/models/layers/__init__.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/llmfoundry/models/layers/__init__.py b/llmfoundry/models/layers/__init__.py index 68aa0fe7fe..05350b059b 100644 --- a/llmfoundry/models/layers/__init__.py +++ b/llmfoundry/models/layers/__init__.py @@ -2,9 +2,9 @@ # SPDX-License-Identifier: Apache-2.0 from llmfoundry.models.layers.attention import ( - ATTN_CLASS_REGISTRY, MultiheadAttention, MultiQueryAttention, - attn_bias_shape, build_alibi_bias, build_attn_bias, flash_attn_fn, - scaled_multihead_dot_product_attention, triton_flash_attn_fn) + ATTN_CLASS_REGISTRY, GroupedQueryAttention, MultiheadAttention, + MultiQueryAttention, attn_bias_shape, build_alibi_bias, build_attn_bias, + flash_attn_fn, scaled_multihead_dot_product_attention, triton_flash_attn_fn) from llmfoundry.models.layers.blocks import MPTBlock from llmfoundry.models.layers.custom_embedding import SharedEmbedding from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY @@ -17,6 +17,7 @@ 'triton_flash_attn_fn', 'MultiheadAttention', 'MultiQueryAttention', + 'GroupedQueryAttention', 'attn_bias_shape', 'build_attn_bias', 'build_alibi_bias', From 19ee086976a767de054d184fa50d3e325587975f Mon Sep 17 00:00:00 2001 From: Vitaliy Chiley <6439018+vchiley@users.noreply.github.com> Date: Thu, 18 Jan 2024 12:44:46 -0800 Subject: [PATCH 22/31] fsdp wrap refac (#883) * fsdp wrap refac * refac * refac --- llmfoundry/models/mpt/modeling_mpt.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index f49b1b88f8..2177124740 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -6,6 +6,8 @@ Inspired by https://github.com/karpathy/minGPT/blob/master/mingpt/model.py """ +from __future__ import annotations + import math import warnings from typing import (Any, Dict, List, Mapping, MutableMapping, Optional, Tuple, @@ -292,6 +294,14 @@ class MPTPreTrainedModel(PreTrainedModel): _no_split_modules = ['MPTBlock'] +def _fsdp_wrap_fn( + self: Union[MPTModel, MPTForCausalLM], + module: nn.Module, +) -> bool: + # FSDP Wrap function for MPT Models + return isinstance(module, MPTBlock) + + class MPTModel(MPTPreTrainedModel): def __init__(self, config: MPTConfig): @@ -728,7 +738,7 @@ def param_init_fn(self, module: nn.Module) -> None: # FSDP Wrap function def fsdp_wrap_fn(self, module: nn.Module) -> bool: - return isinstance(module, MPTBlock) + return _fsdp_wrap_fn(self, module) # Activation Checkpointing def activation_checkpointing_fn(self, module: nn.Module) -> bool: @@ -889,7 +899,7 @@ def param_init_fn(self, module: nn.Module) -> None: # FSDP Wrap function def fsdp_wrap_fn(self, module: nn.Module) -> bool: - return isinstance(module, MPTBlock) + return _fsdp_wrap_fn(self, module) # Activation Checkpointing def activation_checkpointing_fn(self, module: nn.Module) -> bool: From 35bb3390051903656b9aff504426da7524501e56 Mon Sep 17 00:00:00 2001 From: Jerry Chen Date: Thu, 18 Jan 2024 13:58:10 -0800 Subject: [PATCH 23/31] Update model download utils to support ORAS (#881) * wip * wip * Accept registry file for hostname * Make sure no sensitive info is surfaced in subprocess error * Refactor model downloading * Save HF hub files to local dir * fallback * Remove commented code * Update logging * Update HTP download args * Use files for ORAS * Update llmfoundry/utils/model_download_utils.py Co-authored-by: Irene Dea --------- Co-authored-by: Irene Dea --- llmfoundry/utils/__init__.py | 4 +- llmfoundry/utils/model_download_utils.py | 144 +++++++++++++++-------- scripts/misc/download_hf_model.py | 83 ------------- scripts/misc/download_model.py | 115 ++++++++++++++++++ tests/utils/test_model_download_utils.py | 71 ++++++----- 5 files changed, 249 insertions(+), 168 deletions(-) delete mode 100644 scripts/misc/download_hf_model.py create mode 100644 scripts/misc/download_model.py diff --git a/llmfoundry/utils/__init__.py b/llmfoundry/utils/__init__.py index 7abe4dcf75..83af2153a2 100644 --- a/llmfoundry/utils/__init__.py +++ b/llmfoundry/utils/__init__.py @@ -12,7 +12,7 @@ log_config, pop_config, update_batch_size_info) from llmfoundry.utils.model_download_utils import ( - download_from_cache_server, download_from_hf_hub) + download_from_hf_hub, download_from_http_fileserver) except ImportError as e: raise ImportError( 'Please make sure to pip install . to get requirements for llm-foundry.' @@ -28,7 +28,7 @@ 'build_tokenizer', 'calculate_batch_size_info', 'convert_and_save_ft_weights', - 'download_from_cache_server', + 'download_from_http_fileserver', 'download_from_hf_hub', 'get_hf_tokenizer_from_composer_state_dict', 'update_batch_size_info', diff --git a/llmfoundry/utils/model_download_utils.py b/llmfoundry/utils/model_download_utils.py index 2104455e0f..5d8a413d91 100644 --- a/llmfoundry/utils/model_download_utils.py +++ b/llmfoundry/utils/model_download_utils.py @@ -5,6 +5,8 @@ import copy import logging import os +import shutil +import subprocess import time import warnings from http import HTTPStatus @@ -14,6 +16,7 @@ import huggingface_hub as hf_hub import requests import tenacity +import yaml from bs4 import BeautifulSoup from requests.packages.urllib3.exceptions import InsecureRequestWarning from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME @@ -28,6 +31,9 @@ PYTORCH_WEIGHTS_PATTERN = 'pytorch_model*.bin*' SAFE_WEIGHTS_PATTERN = 'model*.safetensors*' +ORAS_PASSWD_PLACEHOLDER = '' +ORAS_CLI = 'oras' + log = logging.getLogger(__name__) @@ -36,8 +42,8 @@ stop=tenacity.stop_after_attempt(3), wait=tenacity.wait_exponential(min=1, max=10)) def download_from_hf_hub( - repo_id: str, - save_dir: Optional[str] = None, + model: str, + save_dir: str, prefer_safetensors: bool = True, token: Optional[str] = None, ): @@ -48,8 +54,7 @@ def download_from_hf_hub( Args: repo_id (str): The Hugging Face Hub repo ID. - save_dir (str, optional): The path to the directory where the model files will be downloaded. If `None`, reads - from the `HUGGINGFACE_HUB_CACHE` environment variable or uses the default Hugging Face Hub cache directory. + save_dir (str, optional): The local path to the directory where the model files will be downloaded. prefer_safetensors (bool): Whether to prefer Safetensors weights over PyTorch weights if both are available. Defaults to True. token (str, optional): The HuggingFace API token. If not provided, the token will be read from the @@ -59,7 +64,7 @@ def download_from_hf_hub( RepositoryNotFoundError: If the model repo doesn't exist or the token is unauthorized. ValueError: If the model repo doesn't contain any supported model weights. """ - repo_files = set(hf_hub.list_repo_files(repo_id)) + repo_files = set(hf_hub.list_repo_files(model)) # Ignore TensorFlow, TensorFlow 2, and Flax weights as they are not supported by Composer. ignore_patterns = copy.deepcopy(DEFAULT_IGNORE_PATTERNS) @@ -86,18 +91,18 @@ def download_from_hf_hub( log.info('Only pytorch available. Ignoring weights preference.') else: raise ValueError( - f'No supported model weights found in repo {repo_id}.' + + f'No supported model weights found in repo {model}.' + ' Please make sure the repo contains either safetensors or pytorch weights.' ) download_start = time.time() - hf_hub.snapshot_download(repo_id, - cache_dir=save_dir, + hf_hub.snapshot_download(model, + local_dir=save_dir, ignore_patterns=ignore_patterns, token=token) download_duration = time.time() - download_start log.info( - f'Downloaded model {repo_id} from Hugging Face Hub in {download_duration} seconds' + f'Downloaded model {model} from Hugging Face Hub in {download_duration} seconds' ) @@ -140,6 +145,7 @@ def _recursive_download( RuntimeError: If the remote server returns a status code other than 200 OK or 401 Unauthorized. """ url = urljoin(base_url, path) + print(url) response = session.get(url, verify=(not ignore_cert)) if response.status_code == HTTPStatus.UNAUTHORIZED: @@ -156,7 +162,7 @@ def _recursive_download( ) # Assume that the URL points to a file if it does not end with a slash. - if not path.endswith('/'): + if not url.endswith('/'): save_path = os.path.join(save_dir, path) parent_dir = os.path.dirname(save_path) if not os.path.exists(parent_dir): @@ -171,6 +177,7 @@ def _recursive_download( # If the URL is a directory, the response should be an HTML directory listing that we can parse for additional links # to download. child_links = _extract_links_from_html(response.content.decode()) + print(child_links) for child_link in child_links: _recursive_download(session, base_url, @@ -183,53 +190,98 @@ def _recursive_download( (PermissionError, ValueError)), stop=tenacity.stop_after_attempt(3), wait=tenacity.wait_exponential(min=1, max=10)) -def download_from_cache_server( - model_name: str, - cache_base_url: str, +def download_from_http_fileserver( + url: str, save_dir: str, - token: Optional[str] = None, ignore_cert: bool = False, ): - """Downloads Hugging Face models from a mirror file server. - - The file server is expected to store the files in the same structure as the Hugging Face cache - structure. See https://huggingface.co/docs/huggingface_hub/guides/manage-cache. + """Downloads files from a remote HTTP file server. Args: - model_name: The name of the model to download. This should be the same as the repository ID in the Hugging Face - Hub. - cache_base_url: The base URL of the cache file server. This function will attempt to download all of the blob - files from `//blobs/`, where `formatted_model_name` is equal to - `models/` with all slashes replaced with `--`. - save_dir: The directory to save the downloaded files to. - token: The Hugging Face API token. If not provided, the token will be read from the `HUGGING_FACE_HUB_TOKEN` - environment variable. - ignore_cert: Whether or not to ignore the validity of the SSL certificate of the remote server. Defaults to - False. + url (str): The base URL where the files are located. + save_dir (str): The directory to save downloaded files to. + ignore_cert (bool): Whether or not to ignore the validity of the SSL certificate of the remote server. + Defaults to False. WARNING: Setting this to true is *not* secure, as no certificate verification will be performed. """ - formatted_model_name = f'models/{model_name}'.replace('/', '--') with requests.Session() as session: - session.headers.update({'Authorization': f'Bearer {token}'}) - - download_start = time.time() - # Temporarily suppress noisy SSL certificate verification warnings if ignore_cert is set to True with warnings.catch_warnings(): if ignore_cert: warnings.simplefilter('ignore', category=InsecureRequestWarning) - # Only downloads the blobs in order to avoid downloading model files twice due to the - # symlnks in the Hugging Face cache structure: - _recursive_download( - session, - cache_base_url, - # Trailing slash to indicate directory - f'{formatted_model_name}/blobs/', - save_dir, - ignore_cert=ignore_cert, - ) - download_duration = time.time() - download_start - log.info( - f'Downloaded model {model_name} from cache server in {download_duration} seconds' + _recursive_download(session, + url, + '', + save_dir, + ignore_cert=ignore_cert) + + +def download_from_oras(model: str, + config_file: str, + credentials_dir: str, + save_dir: str, + concurrency: int = 10): + """Download from an OCI-compliant registry using oras. + + Args: + model: The name of the model to download. + config_file: Path to a YAML config file that maps model names to registry paths. + credentials_dir: Path to a directory containing credentials for the registry. It is expected to contain three + files: `username`, `password`, and `registry`, each of which contains the corresponding credential. + save_dir: Path to the directory where files will be downloaded. + concurrency: The number of concurrent downloads to run. + """ + if shutil.which(ORAS_CLI) is None: + raise Exception( + f'oras cli command `{ORAS_CLI}` is not found. Please install oras: https://oras.land/docs/installation ' ) + + def _read_secrets_file(secret_file_path: str,): + try: + with open(secret_file_path, encoding='utf-8') as f: + return f.read().strip() + except Exception as error: + raise ValueError( + f'secrets file {secret_file_path} failed to be read') from error + + secrets = {} + for secret in ['username', 'password', 'registry']: + secrets[secret] = _read_secrets_file( + os.path.join(credentials_dir, secret)) + + with open(config_file, 'r', encoding='utf-8') as f: + configs = yaml.safe_load(f.read()) + + path = configs['models'][model] + registry = secrets['registry'] + + def get_oras_cmd(username: Optional[str] = None, + password: Optional[str] = None): + cmd = [ + ORAS_CLI, + 'pull', + f'{registry}/{path}', + '-o', + save_dir, + '--verbose', + '--concurrency', + str(concurrency), + ] + if username is not None: + cmd.extend(['--username', username]) + if password is not None: + cmd.extend(['--password', password]) + + return cmd + + cmd_without_creds = get_oras_cmd() + log.info(f'CMD for oras cli to run: {" ".join(cmd_without_creds)}') + cmd_to_run = get_oras_cmd(username=secrets['username'], + password=secrets['password']) + try: + subprocess.run(cmd_to_run, check=True) + except subprocess.CalledProcessError as e: + # Intercept the error and replace the cmd, which may have sensitive info. + raise subprocess.CalledProcessError(e.returncode, cmd_without_creds, + e.output, e.stderr) diff --git a/scripts/misc/download_hf_model.py b/scripts/misc/download_hf_model.py deleted file mode 100644 index 58c3445e7d..0000000000 --- a/scripts/misc/download_hf_model.py +++ /dev/null @@ -1,83 +0,0 @@ -# Copyright 2022 MosaicML LLM Foundry authors -# SPDX-License-Identifier: Apache-2.0 - -"""Script to download model weights from Hugging Face Hub or a cache server.""" -import argparse -import logging -import os -import sys - -from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE - -from llmfoundry.utils.model_download_utils import (download_from_cache_server, - download_from_hf_hub) - -HF_TOKEN_ENV_VAR = 'HUGGING_FACE_HUB_TOKEN' - -logging.basicConfig(format=f'%(asctime)s: %(levelname)s: %(name)s: %(message)s', - level=logging.INFO) -log = logging.getLogger(__name__) - -if __name__ == '__main__': - argparser = argparse.ArgumentParser() - argparser.add_argument('--model', type=str, required=True) - argparser.add_argument('--download-from', - type=str, - choices=['hf', 'cache'], - default='hf') - argparser.add_argument('--token', - type=str, - default=os.getenv(HF_TOKEN_ENV_VAR)) - argparser.add_argument('--save-dir', - type=str, - default=HUGGINGFACE_HUB_CACHE) - argparser.add_argument('--cache-url', type=str, default=None) - argparser.add_argument('--ignore-cert', action='store_true', default=False) - argparser.add_argument( - '--fallback', - action='store_true', - default=True, - help= - 'Whether to fallback to downloading from Hugging Face if download from cache fails', - ) - - args = argparser.parse_args(sys.argv[1:]) - if args.download_from == 'hf': - download_from_hf_hub(args.model, - save_dir=args.save_dir, - token=args.token) - else: - try: - download_from_cache_server( - args.model, - args.cache_url, - args.save_dir, - token=args.token, - ignore_cert=args.ignore_cert, - ) - - # A little hacky: run the Hugging Face download just to repair the symlinks in the HF cache file structure. - # This shouldn't actually download any files if the cache server download was successful, but should address - # a non-deterministic bug where the symlinks aren't repaired properly by the time the model is initialized. - log.info('Repairing Hugging Face cache symlinks') - - # Hide some noisy logs that aren't important for just the symlink repair. - old_level = logging.getLogger().level - logging.getLogger().setLevel(logging.ERROR) - download_from_hf_hub(args.model, - save_dir=args.save_dir, - token=args.token) - logging.getLogger().setLevel(old_level) - - except PermissionError: - log.error(f'Not authorized to download {args.model}.') - except Exception as e: - if args.fallback: - log.warning( - f'Failed to download {args.model} from cache server. Falling back to Hugging Face Hub. Error: {e}' - ) - download_from_hf_hub(args.model, - save_dir=args.save_dir, - token=args.token) - else: - raise e diff --git a/scripts/misc/download_model.py b/scripts/misc/download_model.py new file mode 100644 index 0000000000..1913267e20 --- /dev/null +++ b/scripts/misc/download_model.py @@ -0,0 +1,115 @@ +# Copyright 2022 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +"""Script to download model weights from Hugging Face Hub or a cache server. + +Download from Hugging Face Hub: + python download_model.py hf --model mosaicml/mpt-7b --save-dir --token + +Download from ORAS registry: + python download_model.py oras --registry --path mosaicml/mpt-7b --save-dir + +Download from an HTTP file server: + python download_model.py http --host https://server.com --path mosaicml/mpt-7b --save-dir + +Download from an HTTP file server with fallback to Hugging Face Hub: + python download_model.py http --host https://server.com --path mosaicml/mpt-7b --save-dir \ + fallback-hf --model mosaicml/mpt-7b --token hf_token +""" +import argparse +import logging +import os + +from llmfoundry.utils.model_download_utils import ( + download_from_hf_hub, download_from_http_fileserver, download_from_oras) + +HF_TOKEN_ENV_VAR = 'HUGGING_FACE_HUB_TOKEN' + +logging.basicConfig(format=f'%(asctime)s: %(levelname)s: %(name)s: %(message)s', + level=logging.INFO) +log = logging.getLogger(__name__) + + +def add_hf_parser_arguments(parser: argparse.ArgumentParser) -> None: + parser.add_argument('--model', type=str, required=True) + parser.add_argument('--prefer-safetensors', type=bool, default=True) + parser.add_argument('--token', + type=str, + default=os.getenv(HF_TOKEN_ENV_VAR)) + + +def add_oras_parser_arguments(parser: argparse.ArgumentParser) -> None: + parser.add_argument('--model', type=str, required=True) + parser.add_argument('--config-file', type=str, required=True) + parser.add_argument('--credentials-dir', type=str, required=True) + parser.add_argument('--concurrency', type=int, default=10) + + +def add_http_parser_arguments(parser: argparse.ArgumentParser) -> None: + parser.add_argument('--url', type=str, required=True) + parser.add_argument('--ignore-cert', action='store_true', default=False) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser() + subparsers = parser.add_subparsers(dest='download_from', required=True) + + base_parser = argparse.ArgumentParser(add_help=False) + base_parser.add_argument('--save-dir', type=str, required=True) + + # Add subparser for downloading from Hugging Face Hub. + hf_parser = subparsers.add_parser('hf', parents=[base_parser]) + add_hf_parser_arguments(hf_parser) + + # Add subparser for downloading from ORAS registry. + oras_parser = subparsers.add_parser('oras', parents=[base_parser]) + add_oras_parser_arguments(oras_parser) + + # Add subparser for downloading from an HTTP file server. + http_parser = subparsers.add_parser('http', parents=[base_parser]) + add_http_parser_arguments(http_parser) + + # Add fallbacks for HTTP + fallback_subparsers = http_parser.add_subparsers(dest='fallback') + hf_fallback_parser = fallback_subparsers.add_parser('fallback-hf') + add_hf_parser_arguments(hf_fallback_parser) + + oras_fallback_parser = fallback_subparsers.add_parser('fallback-oras') + add_oras_parser_arguments(oras_fallback_parser) + + return parser.parse_args() + + +if __name__ == '__main__': + args = parse_args() + download_from = args.download_from + + if download_from == 'http': + try: + download_from_http_fileserver(args.url, args.save_dir, + args.ignore_cert) + except PermissionError as e: + log.error(f'Not authorized to download {args.model}.') + raise e + except Exception as e: + log.warning(f'Failed to download from HTTP server with error: {e}') + if args.fallback: + log.warning(f'Falling back to provided fallback destination.') + if args.fallback == 'fallback-hf': + download_from = 'hf' + elif args.fallback == 'fallback-oras': + download_from = 'oras' + else: + raise ValueError( + f'Invalid fallback destination {args.fallback}.') + else: + raise e + + if download_from == 'hf': + download_from_hf_hub(args.model, + save_dir=args.save_dir, + token=args.token, + prefer_safetensors=args.prefer_safetensors) + elif download_from == 'oras': + download_from_oras(args.model, args.config_file, args.credentials_dir, + args.save_dir, args.concurrency) diff --git a/tests/utils/test_model_download_utils.py b/tests/utils/test_model_download_utils.py index 27b9805cda..471a39dcdb 100644 --- a/tests/utils/test_model_download_utils.py +++ b/tests/utils/test_model_download_utils.py @@ -16,11 +16,9 @@ from transformers.utils import WEIGHTS_INDEX_NAME as PYTORCH_WEIGHTS_INDEX_NAME from transformers.utils import WEIGHTS_NAME as PYTORCH_WEIGHTS_NAME -from llmfoundry.utils.model_download_utils import (DEFAULT_IGNORE_PATTERNS, - PYTORCH_WEIGHTS_PATTERN, - SAFE_WEIGHTS_PATTERN, - download_from_cache_server, - download_from_hf_hub) +from llmfoundry.utils.model_download_utils import ( + DEFAULT_IGNORE_PATTERNS, PYTORCH_WEIGHTS_PATTERN, SAFE_WEIGHTS_PATTERN, + download_from_hf_hub, download_from_http_fileserver) # ======================== download_from_hf_hub tests ======================== @@ -103,15 +101,17 @@ def test_download_from_hf_hub_weights_pref(mock_list_repo_files: MagicMock, repo_files: List[str], expected_ignore_patterns: List[str]): test_repo_id = 'test_repo_id' + save_dir = 'save_dir' mock_list_repo_files.return_value = repo_files - download_from_hf_hub(test_repo_id, prefer_safetensors=prefer_safetensors) + download_from_hf_hub(test_repo_id, + save_dir=save_dir, + prefer_safetensors=prefer_safetensors) mock_snapshot_download.assert_called_once_with( test_repo_id, - cache_dir=None, + local_dir=save_dir, ignore_patterns=expected_ignore_patterns, - token=None, - ) + token=None) @mock.patch('huggingface_hub.snapshot_download') @@ -121,10 +121,11 @@ def test_download_from_hf_hub_no_weights( mock_snapshot_download: MagicMock, ): test_repo_id = 'test_repo_id' + save_dir = 'save_dir' mock_list_repo_files.return_value = [] with pytest.raises(ValueError): - download_from_hf_hub(test_repo_id) + download_from_hf_hub(test_repo_id, save_dir) mock_snapshot_download.assert_not_called() @@ -148,12 +149,12 @@ def test_download_from_hf_hub_retry( mock_snapshot_download.side_effect = exception with pytest.raises((tenacity.RetryError, exception.__class__)): - download_from_hf_hub('test_repo_id') + download_from_hf_hub('test_repo_id', 'save_dir') assert mock_snapshot_download.call_count == expected_attempts -# ======================== download_from_cache_server tests ======================== +# ======================== download_from_http_fileserver tests ======================== ROOT_HTML = b""" @@ -182,51 +183,47 @@ def test_download_from_hf_hub_retry( @mock.patch.object(requests.Session, 'get') @mock.patch('os.makedirs') @mock.patch('builtins.open') -def test_download_from_cache_server(mock_open: MagicMock, - mock_makedirs: MagicMock, - mock_get: MagicMock): - cache_url = 'https://cache.com/' - model_name = 'model' - formatted_model_name = 'models--model' +def test_download_from_http_fileserver(mock_open: MagicMock, + mock_makedirs: MagicMock, + mock_get: MagicMock): + model_url = f'https://cache.com/models/model/' save_dir = 'save_dir/' mock_open.return_value = MagicMock() def _server_response(url: str, **kwargs: Dict[str, Any]): - if url == urljoin(cache_url, f'{formatted_model_name}/blobs/'): + if url == model_url: return MagicMock(status_code=HTTPStatus.OK, content=ROOT_HTML) - if url == urljoin(cache_url, f'{formatted_model_name}/blobs/file1'): + if url == urljoin(model_url, 'file1'): return MagicMock(status_code=HTTPStatus.OK) - elif url == urljoin(cache_url, f'{formatted_model_name}/blobs/folder/'): + elif url == urljoin(model_url, 'folder/'): return MagicMock(status_code=HTTPStatus.OK, content=SUBFOLDER_HTML) - elif url == urljoin(cache_url, - f'{formatted_model_name}/blobs/folder/file2'): + elif url == urljoin(model_url, 'folder/file2'): return MagicMock(status_code=HTTPStatus.OK) else: return MagicMock(status_code=HTTPStatus.NOT_FOUND) mock_get.side_effect = _server_response - download_from_cache_server(model_name, cache_url, 'save_dir/') + download_from_http_fileserver(model_url, save_dir) - mock_open.assert_has_calls([ - mock.call(os.path.join(save_dir, formatted_model_name, 'blobs/file1'), - 'wb'), - mock.call( - os.path.join(save_dir, formatted_model_name, 'blobs/folder/file2'), - 'wb'), - ], - any_order=True) + mock_open.assert_has_calls( + [ + mock.call(os.path.join(save_dir, 'file1'), 'wb'), + mock.call(os.path.join(save_dir, 'folder/file2'), 'wb'), + ], + any_order=True, + ) @mock.patch.object(requests.Session, 'get') -def test_download_from_cache_server_unauthorized(mock_get: MagicMock): - cache_url = 'https://cache.com/' +def test_download_from_http_fileserver_unauthorized(mock_get: MagicMock): model_name = 'model' + cache_url = f'https://cache.com/models--{model_name}/blobs/' save_dir = 'save_dir/' mock_get.return_value = MagicMock(status_code=HTTPStatus.UNAUTHORIZED) with pytest.raises(PermissionError): - download_from_cache_server(model_name, cache_url, save_dir) + download_from_http_fileserver(cache_url, save_dir) @pytest.mark.parametrize(['exception', 'expected_attempts'], [ @@ -236,7 +233,7 @@ def test_download_from_cache_server_unauthorized(mock_get: MagicMock): ]) @mock.patch('tenacity.nap.time.sleep') @mock.patch('llmfoundry.utils.model_download_utils._recursive_download') -def test_download_from_cache_server_retry( +def test_download_from_http_fileserver_retry( mock_recursive_download: MagicMock, mock_sleep: MagicMock, # so the retry wait doesn't actually wait exception: BaseException, @@ -245,4 +242,4 @@ def test_download_from_cache_server_retry( mock_recursive_download.side_effect = exception with pytest.raises((tenacity.RetryError, exception.__class__)): - download_from_cache_server('model', 'cache_url', 'save_dir') + download_from_http_fileserver('cache_url', 'save_dir') From 19368c66e9b01d32a97d81548ebbb42b107c4970 Mon Sep 17 00:00:00 2001 From: Brian <23239305+b-chu@users.noreply.github.com> Date: Fri, 19 Jan 2024 15:34:18 -0500 Subject: [PATCH 24/31] Update license (#887) Updates the license for 2024. New files will have a copyright year of 2024 inserted in the header. Existing files will not be changed. --- .ci/FILE_HEADER | 2 +- .pre-commit-config.yaml | 3 ++- setup.py | 3 +++ 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/.ci/FILE_HEADER b/.ci/FILE_HEADER index 22198520fd..e6d99f5d6f 100644 --- a/.ci/FILE_HEADER +++ b/.ci/FILE_HEADER @@ -1,2 +1,2 @@ -Copyright 2022 MosaicML LLM Foundry authors +Copyright 2024 MosaicML LLM Foundry authors SPDX-License-Identifier: Apache-2.0 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d4c8cc699c..4a5bf1f6bb 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -57,7 +57,7 @@ repos: - id: mixed-line-ending - id: trailing-whitespace - repo: https://github.com/Lucas-C/pre-commit-hooks - rev: v1.3.1 + rev: v1.5.4 hooks: - id: insert-license args: @@ -65,6 +65,7 @@ repos: - .ci/FILE_HEADER - --comment-style - '#' + - --allow-past-years types: [python] - repo: https://github.com/PyCQA/docformatter rev: v1.5.0 diff --git a/setup.py b/setup.py index 2c4a05f396..2a2b5cc9cd 100644 --- a/setup.py +++ b/setup.py @@ -1,3 +1,6 @@ +# Copyright 2024 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + """MosaicML LLM Foundry package setup.""" import os From c9a49d0b3db97263f2df2d79f51a386e3e2b69a7 Mon Sep 17 00:00:00 2001 From: Irene Dea Date: Fri, 19 Jan 2024 18:45:22 -0800 Subject: [PATCH 25/31] Fix tiktoken add generation prompt (#890) --- llmfoundry/tokenizers/tiktoken.py | 2 +- tests/tokenizers/test_tiktoken.py | 10 ++++++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/llmfoundry/tokenizers/tiktoken.py b/llmfoundry/tokenizers/tiktoken.py index 2632985533..eaaf0da316 100644 --- a/llmfoundry/tokenizers/tiktoken.py +++ b/llmfoundry/tokenizers/tiktoken.py @@ -198,7 +198,7 @@ def default_chat_template(self): '{% else %}' "{{ '\n' + '<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' }}" '{% endif %}' - '{% if (add_generation_prompt == true) %}' + '{% if (add_generation_prompt == true and loop.last) %}' "{{ '\n' + '<|im_start|>' + 'assistant' + '\n' }}" "{% elif (message['role'] == 'assistant') %}" '{{ eos_token }}' diff --git a/tests/tokenizers/test_tiktoken.py b/tests/tokenizers/test_tiktoken.py index 6a4d1c99c4..aca269af82 100644 --- a/tests/tokenizers/test_tiktoken.py +++ b/tests/tokenizers/test_tiktoken.py @@ -108,6 +108,12 @@ 'Please summarize the goals in this text:\n\nGoing outside has benefits include reducing stress and triggering the relaxation response, which can help us not only feel better mentally, but even heal faster from physical ailments.', 'role': 'user' +}, { + 'content': 'You should go outside and touch grass.', + 'role': 'assistant' +}, { + 'content': 'What else can I do?', + 'role': 'user' }]] MULTI_TURN_GENERATE_STRING = [ @@ -118,6 +124,10 @@ Going outside has benefits include reducing stress and triggering the relaxation response, which can help us not only feel better mentally, but even heal faster from physical ailments.<|im_end|> <|im_start|>assistant +You should go outside and touch grass.<|im_end|><|endoftext|> +<|im_start|>user +What else can I do?<|im_end|> +<|im_start|>assistant """ ] From b2a0c03bc2684912f3fde0dbb798e62dc70607fd Mon Sep 17 00:00:00 2001 From: Daniel King <43149077+dakinggg@users.noreply.github.com> Date: Sun, 21 Jan 2024 17:34:52 -0800 Subject: [PATCH 26/31] Upgrade Datasets version (#892) --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 2a2b5cc9cd..8c43a309c3 100644 --- a/setup.py +++ b/setup.py @@ -55,7 +55,7 @@ 'transformers>=4.36,<4.37', 'mosaicml-streaming>=0.7.2,<0.8', 'torch>=2.1,<2.1.1', - 'datasets==2.15.0', + 'datasets>=2.16,<2.17', 'fsspec==2023.6.0', # newer version results in a bug in datasets that duplicates data 'sentencepiece==0.1.97', 'einops==0.7.0', From 02c44ad6837a5aa61f9268f322b8ca19b8917dde Mon Sep 17 00:00:00 2001 From: Daniel King <43149077+dakinggg@users.noreply.github.com> Date: Mon, 22 Jan 2024 17:43:58 -0800 Subject: [PATCH 27/31] Bump transformers version to support Mixtral (#894) --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 8c43a309c3..511e665ed4 100644 --- a/setup.py +++ b/setup.py @@ -52,7 +52,7 @@ install_requires = [ 'mosaicml[libcloud,wandb,mlflow,oci,gcs]>=0.17.2,<0.18', 'accelerate>=0.25,<0.26', # for HF inference `device_map` - 'transformers>=4.36,<4.37', + 'transformers>=4.37,<4.38', 'mosaicml-streaming>=0.7.2,<0.8', 'torch>=2.1,<2.1.1', 'datasets>=2.16,<2.17', From 07d6db36b5399cec77dba4d9aa31a153105b348b Mon Sep 17 00:00:00 2001 From: Irene Dea Date: Mon, 22 Jan 2024 19:16:08 -0800 Subject: [PATCH 28/31] Add `tokenizer-only` flag to only download tokenizers from HF or oras (#895) --- llmfoundry/utils/model_download_utils.py | 26 ++++++++++++++++++------ scripts/misc/download_model.py | 20 ++++++++++++++---- tests/utils/test_model_download_utils.py | 1 + 3 files changed, 37 insertions(+), 10 deletions(-) diff --git a/llmfoundry/utils/model_download_utils.py b/llmfoundry/utils/model_download_utils.py index 5d8a413d91..07c84a85c8 100644 --- a/llmfoundry/utils/model_download_utils.py +++ b/llmfoundry/utils/model_download_utils.py @@ -30,6 +30,12 @@ ] PYTORCH_WEIGHTS_PATTERN = 'pytorch_model*.bin*' SAFE_WEIGHTS_PATTERN = 'model*.safetensors*' +TOKENIZER_FILES = [ + 'special_tokens_map.json', + 'tokenizer.json', + 'tokenizer.model', + 'tokenizer_config.json', +] ORAS_PASSWD_PLACEHOLDER = '' ORAS_CLI = 'oras' @@ -45,6 +51,7 @@ def download_from_hf_hub( model: str, save_dir: str, prefer_safetensors: bool = True, + tokenizer_only: bool = False, token: Optional[str] = None, ): """Downloads model files from a Hugging Face Hub model repo. @@ -57,6 +64,7 @@ def download_from_hf_hub( save_dir (str, optional): The local path to the directory where the model files will be downloaded. prefer_safetensors (bool): Whether to prefer Safetensors weights over PyTorch weights if both are available. Defaults to True. + tokenizer_only (bool): If true, only download tokenizer files. token (str, optional): The HuggingFace API token. If not provided, the token will be read from the `HUGGING_FACE_HUB_TOKEN` environment variable. @@ -95,10 +103,13 @@ def download_from_hf_hub( ' Please make sure the repo contains either safetensors or pytorch weights.' ) + allow_patterns = TOKENIZER_FILES if tokenizer_only else None + download_start = time.time() hf_hub.snapshot_download(model, local_dir=save_dir, ignore_patterns=ignore_patterns, + allow_patterns=allow_patterns, token=token) download_duration = time.time() - download_start log.info( @@ -221,16 +232,18 @@ def download_from_oras(model: str, config_file: str, credentials_dir: str, save_dir: str, + tokenizer_only: bool = False, concurrency: int = 10): """Download from an OCI-compliant registry using oras. Args: - model: The name of the model to download. - config_file: Path to a YAML config file that maps model names to registry paths. - credentials_dir: Path to a directory containing credentials for the registry. It is expected to contain three + model (str): The name of the model to download. + config_file (str): Path to a YAML config file that maps model and tokenizer names to registry paths. + credentials_dir (str): Path to a directory containing credentials for the registry. It is expected to contain three files: `username`, `password`, and `registry`, each of which contains the corresponding credential. - save_dir: Path to the directory where files will be downloaded. - concurrency: The number of concurrent downloads to run. + save_dir (str): Path to the directory where files will be downloaded. + tokenizer_only (bool): If true, only download the tokenzier files. + concurrency (int): The number of concurrent downloads to run. """ if shutil.which(ORAS_CLI) is None: raise Exception( @@ -253,7 +266,8 @@ def _read_secrets_file(secret_file_path: str,): with open(config_file, 'r', encoding='utf-8') as f: configs = yaml.safe_load(f.read()) - path = configs['models'][model] + config_type = 'tokenizers' if tokenizer_only else 'models' + path = configs[config_type][model] registry = secrets['registry'] def get_oras_cmd(username: Optional[str] = None, diff --git a/scripts/misc/download_model.py b/scripts/misc/download_model.py index 1913267e20..13a63ce55e 100644 --- a/scripts/misc/download_model.py +++ b/scripts/misc/download_model.py @@ -7,10 +7,11 @@ python download_model.py hf --model mosaicml/mpt-7b --save-dir --token Download from ORAS registry: - python download_model.py oras --registry --path mosaicml/mpt-7b --save-dir + python download_model.py oras --model mosaicml/mpt-7b --config-file \ + --credentials-dir --save-dir Download from an HTTP file server: - python download_model.py http --host https://server.com --path mosaicml/mpt-7b --save-dir + python download_model.py http --url https://server.com/models/mosaicml/mpt-7b/ --save-dir Download from an HTTP file server with fallback to Hugging Face Hub: python download_model.py http --host https://server.com --path mosaicml/mpt-7b --save-dir \ @@ -56,6 +57,9 @@ def parse_args() -> argparse.Namespace: base_parser = argparse.ArgumentParser(add_help=False) base_parser.add_argument('--save-dir', type=str, required=True) + base_parser.add_argument('--tokenizer-only', + default=False, + action='store_true') # Add subparser for downloading from Hugging Face Hub. hf_parser = subparsers.add_parser('hf', parents=[base_parser]) @@ -85,6 +89,9 @@ def parse_args() -> argparse.Namespace: download_from = args.download_from if download_from == 'http': + if args.tokenizer_only: + raise ValueError( + 'tokenizer-only is not currently supported for http.') try: download_from_http_fileserver(args.url, args.save_dir, args.ignore_cert) @@ -109,7 +116,12 @@ def parse_args() -> argparse.Namespace: download_from_hf_hub(args.model, save_dir=args.save_dir, token=args.token, + tokenizer_only=args.tokenizer_only, prefer_safetensors=args.prefer_safetensors) elif download_from == 'oras': - download_from_oras(args.model, args.config_file, args.credentials_dir, - args.save_dir, args.concurrency) + download_from_oras(args.model, + args.config_file, + args.credentials_dir, + args.save_dir, + tokenizer_only=args.tokenizer_only, + concurrency=args.concurrency) diff --git a/tests/utils/test_model_download_utils.py b/tests/utils/test_model_download_utils.py index 471a39dcdb..14749bdcd9 100644 --- a/tests/utils/test_model_download_utils.py +++ b/tests/utils/test_model_download_utils.py @@ -110,6 +110,7 @@ def test_download_from_hf_hub_weights_pref(mock_list_repo_files: MagicMock, mock_snapshot_download.assert_called_once_with( test_repo_id, local_dir=save_dir, + allow_patterns=None, ignore_patterns=expected_ignore_patterns, token=None) From f2614a4a609e1575073c02422077ae4e89b82549 Mon Sep 17 00:00:00 2001 From: Anna Date: Mon, 22 Jan 2024 21:29:23 -0800 Subject: [PATCH 29/31] Foundational Model API eval wrapper (#849) * FMAPI model wrapper * add chat wrapper too * revert * end line * formatting * less verbose * better error messages --- .../models/inference_api_wrapper/__init__.py | 4 ++ .../models/inference_api_wrapper/fmapi.py | 72 +++++++++++++++++++ .../inference_api_wrapper/openai_causal_lm.py | 27 +++++-- llmfoundry/models/model_registry.py | 8 ++- 4 files changed, 104 insertions(+), 7 deletions(-) create mode 100644 llmfoundry/models/inference_api_wrapper/fmapi.py diff --git a/llmfoundry/models/inference_api_wrapper/__init__.py b/llmfoundry/models/inference_api_wrapper/__init__.py index 496abf2aa6..9bb2ece2b2 100644 --- a/llmfoundry/models/inference_api_wrapper/__init__.py +++ b/llmfoundry/models/inference_api_wrapper/__init__.py @@ -1,6 +1,8 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 +from llmfoundry.models.inference_api_wrapper.fmapi import ( + FMAPICasualLMEvalWrapper, FMAPIChatAPIEvalWrapper) from llmfoundry.models.inference_api_wrapper.interface import \ InferenceAPIEvalWrapper from llmfoundry.models.inference_api_wrapper.openai_causal_lm import ( @@ -10,4 +12,6 @@ 'OpenAICausalLMEvalWrapper', 'OpenAIChatAPIEvalWrapper', 'InferenceAPIEvalWrapper', + 'FMAPICasualLMEvalWrapper', + 'FMAPIChatAPIEvalWrapper', ] diff --git a/llmfoundry/models/inference_api_wrapper/fmapi.py b/llmfoundry/models/inference_api_wrapper/fmapi.py new file mode 100644 index 0000000000..867b3c272e --- /dev/null +++ b/llmfoundry/models/inference_api_wrapper/fmapi.py @@ -0,0 +1,72 @@ +# Copyright 2022 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +import logging +import os +import time +from typing import Dict + +import requests +from transformers import AutoTokenizer + +from llmfoundry.models.inference_api_wrapper.openai_causal_lm import ( + OpenAICausalLMEvalWrapper, OpenAIChatAPIEvalWrapper, OpenAIEvalInterface) + +__all__ = [ + 'FMAPICasualLMEvalWrapper', + 'FMAPIChatAPIEvalWrapper', +] + +log = logging.getLogger(__name__) + + +def block_until_ready(base_url: str): + """Block until the endpoint is ready.""" + sleep_s = 5 + timout_s = 5 * 60 # At max, wait 5 minutes + + ping_url = f'{base_url}/ping' + + waited_s = 0 + while True: + try: + requests.get(ping_url) + log.info(f'Endpoint {ping_url} is ready') + break + except requests.exceptions.ConnectionError: + log.debug( + f'Endpoint {ping_url} not ready yet. Sleeping {sleep_s} seconds' + ) + time.sleep(sleep_s) + waited_s += sleep_s + + if waited_s >= timout_s: + raise TimeoutError( + f'Endpoint {ping_url} did not become read after {waited_s:,} seconds, exiting' + ) + + +class FMAPIEvalInterface(OpenAIEvalInterface): + + def __init__(self, model_cfg: Dict, tokenizer: AutoTokenizer): + is_local = model_cfg.pop('local', False) + if is_local: + base_url = os.environ.get('MOSAICML_MODEL_ENDPOINT', + 'http://0.0.0.0:8080/v2') + model_cfg['base_url'] = base_url + block_until_ready(base_url) + + if 'base_url' not in model_cfg: + raise ValueError( + 'Must specify base_url or use local=True in model_cfg for FMAPIsEvalWrapper' + ) + + super().__init__(model_cfg, tokenizer) + + +class FMAPICasualLMEvalWrapper(FMAPIEvalInterface, OpenAICausalLMEvalWrapper): + """Databricks Foundational Model API wrapper for causal LM models.""" + + +class FMAPIChatAPIEvalWrapper(FMAPIEvalInterface, OpenAIChatAPIEvalWrapper): + """Databricks Foundational Model API wrapper for chat models.""" diff --git a/llmfoundry/models/inference_api_wrapper/openai_causal_lm.py b/llmfoundry/models/inference_api_wrapper/openai_causal_lm.py index 39de2ba59c..587dd179bd 100644 --- a/llmfoundry/models/inference_api_wrapper/openai_causal_lm.py +++ b/llmfoundry/models/inference_api_wrapper/openai_causal_lm.py @@ -36,9 +36,6 @@ class OpenAIEvalInterface(InferenceAPIEvalWrapper): def __init__(self, model_cfg: Dict, tokenizer: AutoTokenizer) -> None: super().__init__(model_cfg, tokenizer) - assert os.getenv( - 'OPENAI_API_KEY' - ) is not None, 'No OpenAI API Key found. Ensure it is saved as an environmental variable called OPENAI_API_KEY.' try: import openai except ImportError as e: @@ -46,8 +43,28 @@ def __init__(self, model_cfg: Dict, tokenizer: AutoTokenizer) -> None: extra_deps_group='openai', conda_package='openai', conda_channel='conda-forge') from e - self.client = openai.OpenAI() - self.model_name = model_cfg['version'] + + api_key = os.environ.get('OPENAI_API_KEY') + base_url = model_cfg.get('base_url') + if base_url is None: + # Using OpenAI default, where the API key is required + if api_key is None: + raise ValueError( + 'No OpenAI API Key found. Ensure it is saved as an environmental variable called OPENAI_API_KEY.' + ) + + else: + # Using a custom base URL, where the API key may not be required + log.info( + f'Making request to custom base URL: {base_url}{"" if api_key is not None else " (no API key set)"}' + ) + api_key = 'placeholder' # This cannot be None + + self.client = openai.OpenAI(base_url=base_url, api_key=api_key) + if 'version' in model_cfg: + self.model_name = model_cfg['version'] + else: + self.model_name = model_cfg['name'] def generate_completion(self, prompt: str, num_tokens: int): raise NotImplementedError() diff --git a/llmfoundry/models/model_registry.py b/llmfoundry/models/model_registry.py index be09a69835..ff9942f5f6 100644 --- a/llmfoundry/models/model_registry.py +++ b/llmfoundry/models/model_registry.py @@ -3,7 +3,9 @@ from llmfoundry.models.hf import (ComposerHFCausalLM, ComposerHFPrefixLM, ComposerHFT5) -from llmfoundry.models.inference_api_wrapper import (OpenAICausalLMEvalWrapper, +from llmfoundry.models.inference_api_wrapper import (FMAPICasualLMEvalWrapper, + FMAPIChatAPIEvalWrapper, + OpenAICausalLMEvalWrapper, OpenAIChatAPIEvalWrapper) from llmfoundry.models.mpt import ComposerMPTCausalLM @@ -13,5 +15,7 @@ 'hf_prefix_lm': ComposerHFPrefixLM, 'hf_t5': ComposerHFT5, 'openai_causal_lm': OpenAICausalLMEvalWrapper, - 'openai_chat': OpenAIChatAPIEvalWrapper + 'fmapi_causal_lm': FMAPICasualLMEvalWrapper, + 'openai_chat': OpenAIChatAPIEvalWrapper, + 'fmapi_chat': FMAPIChatAPIEvalWrapper, } From 36fcb5e59db73eb09bfcc93087ef12fb6ecc8975 Mon Sep 17 00:00:00 2001 From: Irene Dea Date: Tue, 23 Jan 2024 11:25:49 -0800 Subject: [PATCH 30/31] Add better error for non-empty local output folder in convert_text_to_mds.py (#891) --- scripts/data_prep/convert_text_to_mds.py | 4 + .../data_prep/test_convert_text_to_mds.py | 82 ++++++++++--------- 2 files changed, 47 insertions(+), 39 deletions(-) diff --git a/scripts/data_prep/convert_text_to_mds.py b/scripts/data_prep/convert_text_to_mds.py index 2218e575b2..bfd60b8ee1 100644 --- a/scripts/data_prep/convert_text_to_mds.py +++ b/scripts/data_prep/convert_text_to_mds.py @@ -385,6 +385,10 @@ def convert_text_to_mds( local_output_folder = tempfile.TemporaryDirectory( ).name if is_remote_output else output_folder + if os.path.isdir(output_folder) and len(os.listdir(output_folder)) > 0: + raise FileExistsError( + f'{output_folder=} is not empty. Please remove or empty it.') + if processes > 1: # Download and convert the text files in parallel args = get_task_args(object_names, local_output_folder, input_folder, diff --git a/tests/a_scripts/data_prep/test_convert_text_to_mds.py b/tests/a_scripts/data_prep/test_convert_text_to_mds.py index cc293a2cdd..3a00a8889f 100644 --- a/tests/a_scripts/data_prep/test_convert_text_to_mds.py +++ b/tests/a_scripts/data_prep/test_convert_text_to_mds.py @@ -3,6 +3,7 @@ import os import pathlib +import shutil from concurrent.futures import ProcessPoolExecutor from glob import glob from typing import Callable, Iterable, List @@ -55,23 +56,6 @@ def upload_object(self, object_name: str, filename: str): remote_file.write(local_file.read()) -def _call_convert_text_to_mds(processes: int, tokenizer_name: str, - concat_tokens: int) -> None: - convert_text_to_mds( - tokenizer_name=tokenizer_name, - output_folder=f's3://fake-test-output-path', - input_folder=f's3://fake-test-input-path', - concat_tokens=concat_tokens, - eos_text='', - bos_text='', - no_wrap=False, - compression='zstd', - processes=processes, - args_str='Namespace()', - reprocess=False, - ) - - # Mock starmap with no multiprocessing def _mock_map(func: Callable, args: Iterable) -> Iterable: for arg in args: @@ -107,9 +91,22 @@ def test_single_and_multi_process(merge_shard_groups: Mock, maybe_create_object_store_from_uri.return_value = mock_object_store parse_uri.return_value = ('s3', 'fake-test-bucket', str(remote_folder)) - _call_convert_text_to_mds(processes=processes, - tokenizer_name=tokenizer_name, - concat_tokens=concat_tokens) + def call_convert_text_to_mds() -> None: + convert_text_to_mds( + tokenizer_name=tokenizer_name, + output_folder=f's3://fake-test-output-path', + input_folder=f's3://fake-test-input-path', + concat_tokens=concat_tokens, + eos_text='', + bos_text='', + no_wrap=False, + compression='zstd', + processes=processes, + args_str='Namespace()', + reprocess=False, + ) + + call_convert_text_to_mds() # Check call counts assert download_and_convert.call_count == processes # called once per process @@ -131,9 +128,7 @@ def test_single_and_multi_process(merge_shard_groups: Mock, _assert_files_exist(prefix=remote_folder, files=['index.json', DONE_FILENAME] + shards) - _call_convert_text_to_mds(processes=processes, - tokenizer_name=tokenizer_name, - concat_tokens=concat_tokens) + call_convert_text_to_mds() # Check call counts assert download_and_convert.call_count == processes # No changes because we shoudn't reprocess @@ -146,9 +141,7 @@ def test_single_and_multi_process(merge_shard_groups: Mock, mock_object_store = Mock(wraps=object_store) maybe_create_object_store_from_uri.return_value = mock_object_store - _call_convert_text_to_mds(processes=processes, - tokenizer_name=tokenizer_name, - concat_tokens=concat_tokens) + call_convert_text_to_mds() # Check call counts assert download_and_convert.call_count == processes * 2 # called once per process @@ -187,31 +180,42 @@ def test_local_path(tmp_path: pathlib.Path): input_folder = tmp_path / 'input' output_folder = tmp_path / 'output' + def call_convert_text_to_mds(reprocess: bool): + convert_text_to_mds( + tokenizer_name='mosaicml/mpt-7b', + output_folder=str(output_folder), + input_folder=str(input_folder), + concat_tokens=1, + eos_text='', + bos_text='', + no_wrap=False, + compression='zstd', + processes=1, + args_str='Namespace()', + reprocess=reprocess, + ) + # Create input text data os.makedirs(input_folder, exist_ok=True) with open(input_folder / 'test.txt', 'w') as f: f.write('test') # Convert text data to mds - convert_text_to_mds( - tokenizer_name='mosaicml/mpt-7b', - output_folder=str(output_folder), - input_folder=str(input_folder), - concat_tokens=1, - eos_text='', - bos_text='', - no_wrap=False, - compression='zstd', - processes=1, - args_str='Namespace()', - reprocess=False, - ) + call_convert_text_to_mds(reprocess=False) # Make sure all the files exist as expected. assert os.path.exists(output_folder / '.text_to_mds_conversion_done') assert os.path.exists(output_folder / 'index.json') assert os.path.exists(output_folder / 'shard.00000.mds.zstd') + # Test reprocessing. + with pytest.raises(FileExistsError): + call_convert_text_to_mds(reprocess=True) + + shutil.rmtree(output_folder) + + call_convert_text_to_mds(reprocess=True) + def test_is_already_processed(tmp_path: pathlib.Path): tmp_path_str = str(tmp_path) From 4961436ff199837cdee223a1d47e7d0d258fb4fd Mon Sep 17 00:00:00 2001 From: Nicholas Garcia Date: Tue, 23 Jan 2024 14:14:46 -0800 Subject: [PATCH 31/31] Allow bool input for loggers (#897) * Allow bool input for loggers * Convert earlier on * Fix test case --- llmfoundry/utils/builders.py | 15 +++++---------- scripts/train/train.py | 3 ++- tests/utils/test_builders.py | 4 ++-- 3 files changed, 9 insertions(+), 13 deletions(-) diff --git a/llmfoundry/utils/builders.py b/llmfoundry/utils/builders.py index 75438b895e..29642381f8 100644 --- a/llmfoundry/utils/builders.py +++ b/llmfoundry/utils/builders.py @@ -219,21 +219,16 @@ def build_callback( def build_logger(name: str, kwargs: Dict[str, Any]) -> LoggerDestination: - kwargs_dict = { - k: v if isinstance(v, str) else om.to_container(v, resolve=True) - for k, v in kwargs.items() - } - if name == 'wandb': - return WandBLogger(**kwargs_dict) + return WandBLogger(**kwargs) elif name == 'tensorboard': - return TensorboardLogger(**kwargs_dict) + return TensorboardLogger(**kwargs) elif name == 'in_memory_logger': - return InMemoryLogger(**kwargs_dict) + return InMemoryLogger(**kwargs) elif name == 'mlflow': - return MLFlowLogger(**kwargs_dict) + return MLFlowLogger(**kwargs) elif name == 'inmemory': - return InMemoryLogger(**kwargs_dict) + return InMemoryLogger(**kwargs) else: raise ValueError(f'Not sure how to build logger: {name}') diff --git a/scripts/train/train.py b/scripts/train/train.py index c3da1f1d3c..638ad8aaea 100644 --- a/scripts/train/train.py +++ b/scripts/train/train.py @@ -278,7 +278,8 @@ def main(cfg: DictConfig) -> Trainer: logger_configs: Optional[DictConfig] = pop_config(cfg, 'loggers', must_exist=False, - default_value=None) + default_value=None, + convert=True) callback_configs: Optional[DictConfig] = pop_config(cfg, 'callbacks', must_exist=False, diff --git a/tests/utils/test_builders.py b/tests/utils/test_builders.py index 9be6630075..303afc9b7d 100644 --- a/tests/utils/test_builders.py +++ b/tests/utils/test_builders.py @@ -135,14 +135,14 @@ def test_build_logger(): with pytest.raises(ValueError): _ = build_logger('unknown', {}) - logger_cfg = DictConfig({ + logger_cfg = { 'project': 'foobar', 'init_kwargs': { 'config': { 'foo': 'bar', } } - }) + } wandb_logger = build_logger('wandb', logger_cfg) # type: ignore assert isinstance(wandb_logger, WandBLogger) assert wandb_logger.project == 'foobar'