diff --git a/.ci/FILE_HEADER b/.ci/FILE_HEADER index 22198520fd..e6d99f5d6f 100644 --- a/.ci/FILE_HEADER +++ b/.ci/FILE_HEADER @@ -1,2 +1,2 @@ -Copyright 2022 MosaicML LLM Foundry authors +Copyright 2024 MosaicML LLM Foundry authors SPDX-License-Identifier: Apache-2.0 diff --git a/.gitignore b/.gitignore index 989fb3af0c..d041a25c22 100644 --- a/.gitignore +++ b/.gitignore @@ -155,3 +155,4 @@ notebooks/ **/*.pt **/mlruns/* **/tokenizer-save-dir-*/** +**/.downloaded_finetuning/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d4c8cc699c..4a5bf1f6bb 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -57,7 +57,7 @@ repos: - id: mixed-line-ending - id: trailing-whitespace - repo: https://github.com/Lucas-C/pre-commit-hooks - rev: v1.3.1 + rev: v1.5.4 hooks: - id: insert-license args: @@ -65,6 +65,7 @@ repos: - .ci/FILE_HEADER - --comment-style - '#' + - --allow-past-years types: [python] - repo: https://github.com/PyCQA/docformatter rev: v1.5.0 diff --git a/llmfoundry/callbacks/async_eval_callback.py b/llmfoundry/callbacks/async_eval_callback.py index 93bcbe3bcb..f019303808 100644 --- a/llmfoundry/callbacks/async_eval_callback.py +++ b/llmfoundry/callbacks/async_eval_callback.py @@ -166,31 +166,21 @@ def get_eval_parameters( def validate_interval(interval: Union[str, int, Time], save_interval: Union[str, int, Time]) -> Time: - if isinstance(save_interval, str): - new_save_interval: Time = Time.from_timestring(save_interval) - elif isinstance(save_interval, int): - new_save_interval: Time = Time(save_interval, TimeUnit.EPOCH) - else: - new_save_interval: Time = save_interval - if isinstance(interval, str): - result: Time = Time.from_timestring(interval) - elif isinstance(interval, int): - result: Time = Time(interval, TimeUnit.EPOCH) - else: - result: Time = interval + new_save_interval = Time.from_input(save_interval, TimeUnit.EPOCH) + async_interval = Time.from_input(interval, TimeUnit.EPOCH) - if new_save_interval.unit != result.unit: + if new_save_interval.unit != async_interval.unit: raise ValueError( 'Save interval and async eval interval must be in the same unit') - if result < new_save_interval: + if async_interval < new_save_interval: raise ValueError( 'Async eval interval must be equal or greater (less frequent) than save interval' ) - if result.value % new_save_interval.value != 0: + if async_interval.value % new_save_interval.value != 0: raise ValueError( 'Async eval interval must be a multiple of save interval') - return result + return async_interval def validate_eval_run_config( diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py index 491d510188..b50d81d09e 100644 --- a/llmfoundry/callbacks/hf_checkpointer.py +++ b/llmfoundry/callbacks/hf_checkpointer.py @@ -96,14 +96,10 @@ def __init__( self.huggingface_folder_name_fstr = os.path.join( 'huggingface', huggingface_folder_name) - if isinstance(save_interval, str): - save_interval = Time.from_timestring(save_interval) - if isinstance(save_interval, int): - save_interval = Time(save_interval, TimeUnit.EPOCH) - - self.save_interval: Time = save_interval + self.save_interval: Time = Time.from_input(save_interval, + TimeUnit.EPOCH) self.check_interval = create_interval_scheduler( - save_interval, include_end_of_training=True) + self.save_interval, include_end_of_training=True) self.remote_ud = maybe_create_remote_uploader_downloader_from_uri( save_folder, loggers=[]) if self.remote_ud is not None: @@ -245,11 +241,16 @@ def _save_checkpoint(self, state: State, logger: Logger): ) if self.remote_ud is not None: - log.info(f'Uploading HuggingFace formatted checkpoint') for filename in os.listdir(temp_save_dir): + remote_file_name = os.path.join(save_dir, filename) + remote_file_uri = self.remote_ud.remote_backend.get_uri( + remote_file_name) + log.info( + f'Uploading HuggingFace formatted checkpoint to {remote_file_uri}' + ) self.remote_ud.upload_file( state=state, - remote_file_name=os.path.join(save_dir, filename), + remote_file_name=remote_file_name, file_path=Path(os.path.join(temp_save_dir, filename)), overwrite=self.overwrite, diff --git a/llmfoundry/data/finetuning/dataloader.py b/llmfoundry/data/finetuning/dataloader.py index 7a29d1dfed..4e1c3bbf9f 100644 --- a/llmfoundry/data/finetuning/dataloader.py +++ b/llmfoundry/data/finetuning/dataloader.py @@ -4,7 +4,6 @@ import os from typing import Tuple, Union -import datasets as hf_datasets import torch from composer.core.data_spec import DataSpec from composer.utils import dist, get_file, parse_uri @@ -13,7 +12,9 @@ from transformers import PreTrainedTokenizerBase from llmfoundry.data.finetuning.collator import Seq2SeqFinetuningCollator -from llmfoundry.data.finetuning.tasks import dataset_constructor +from llmfoundry.data.finetuning.tasks import (DOWNLOADED_FT_DATASETS_DIRPATH, + SUPPORTED_EXTENSIONS, + dataset_constructor) from llmfoundry.data.packing import BinPackCollator, auto_packing_ratio from llmfoundry.data.text_data import get_tokens_per_batch_func @@ -122,8 +123,13 @@ def build_finetuning_dataloader(cfg: DictConfig, if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token + collate_fn, dataloader_batch_size = _build_collate_fn( + cfg, tokenizer, device_batch_size) + dataset = None # for pyright + sampler = None if cfg.dataset.get('remote') is not None: + # Build streaming dataloader dataset = dataset_constructor.build_from_streaming( tokenizer=tokenizer, local=cfg.dataset.local, @@ -148,40 +154,45 @@ def build_finetuning_dataloader(cfg: DictConfig, batching_method=cfg.dataset.get('batching_method', 'random'), ) - collate_fn, dataloader_batch_size = _build_collate_fn( - cfg, tokenizer, device_batch_size) - - dl = DataLoader( - dataset, - collate_fn=collate_fn, - batch_size=dataloader_batch_size, - drop_last=cfg.drop_last, - num_workers=cfg.num_workers, - pin_memory=cfg.get('pin_memory', True), - prefetch_factor=cfg.get('prefetch_factor', 2), - persistent_workers=cfg.get('persistent_workers', True), - timeout=cfg.get('timeout', 0), - ) - else: - backend, _, _ = parse_uri(cfg.dataset.hf_name) + # Build HF dataloader + dataset_name_or_path = cfg.dataset.hf_name + split = cfg.dataset.get('split') + + # If dataset is a remote path, download it first. + backend, _, _ = parse_uri(dataset_name_or_path) if backend not in ['', None]: - if cfg.dataset.get('split') is None: + if split is None: raise ValueError( 'When using a HuggingFace dataset from a URL, you must set the ' + \ '`split` key in the dataset config.' ) - dataset = _build_hf_dataset_from_remote(cfg, tokenizer) + # HF datasets does not support a split with dashes, so we replace dashes + # with underscores. + split = split.replace('-', '_') + dataset_name_or_path = _download_remote_hf_dataset( + remote_path=dataset_name_or_path, split=split) + + # Get the preprocessing function. + proto_preprocessing_fn = cfg.dataset.get('preprocessing_fn') + if isinstance(proto_preprocessing_fn, (dict, DictConfig)): + preprocessing_fn = dataset_constructor.get_preprocessing_fn_from_dict( + dict(proto_preprocessing_fn)) else: - dataset = dataset_constructor.build_from_hf( - cfg.dataset, - max_seq_len=cfg.dataset.max_seq_len, - tokenizer=tokenizer, - ) + preprocessing_fn = dataset_constructor.get_preprocessing_fn_from_str( + proto_preprocessing_fn, dataset_name_or_path) - collate_fn, dataloader_batch_size = _build_collate_fn( - cfg, tokenizer, device_batch_size) + # Build dataset from HF. + dataset = dataset_constructor.build_from_hf( + dataset_name=dataset_name_or_path, + split=split, + safe_load=cfg.dataset.get('safe_load', False), + max_seq_len=cfg.dataset.max_seq_len, + preprocessing_fn=preprocessing_fn, + tokenizer=tokenizer, + hf_kwargs=cfg.dataset.get('hf_kwargs', {})) + # Ensure dataset is large enough. if cfg.drop_last: world_size = dist.get_world_size() minimum_dataset_size = world_size * dataloader_batch_size @@ -189,7 +200,7 @@ def build_finetuning_dataloader(cfg: DictConfig, full_dataset_size = len(dataset) if full_dataset_size < minimum_dataset_size: raise ValueError( - f'Your dataset (name={cfg.dataset.hf_name}, split={cfg.dataset.split}) ' + f'Your dataset (name={cfg.dataset.hf_name}, split={split}) ' + f'has {full_dataset_size} samples, but your minimum batch size ' + @@ -199,22 +210,24 @@ def build_finetuning_dataloader(cfg: DictConfig, + f'of samples in your dataset to at least {minimum_dataset_size}.' ) - - assert dataset is not None - dl = DataLoader( - dataset, - collate_fn=collate_fn, - batch_size=dataloader_batch_size, - drop_last=cfg.drop_last, - sampler=dist.get_sampler(dataset, - drop_last=cfg.drop_last, - shuffle=cfg.dataset.shuffle), - num_workers=cfg.num_workers, - pin_memory=cfg.get('pin_memory', True), - prefetch_factor=cfg.get('prefetch_factor', 2), - persistent_workers=cfg.get('persistent_workers', True), - timeout=cfg.get('timeout', 0), - ) + # Initialize sampler. + sampler = dist.get_sampler(dataset, + drop_last=cfg.drop_last, + shuffle=cfg.dataset.shuffle) + + assert dataset is not None # for pyright + dl = DataLoader( + dataset, + collate_fn=collate_fn, + batch_size=dataloader_batch_size, + drop_last=cfg.drop_last, + sampler=sampler, + num_workers=cfg.num_workers, + pin_memory=cfg.get('pin_memory', True), + prefetch_factor=cfg.get('prefetch_factor', 2), + persistent_workers=cfg.get('persistent_workers', True), + timeout=cfg.get('timeout', 0), + ) token_counting_func = get_tokens_per_batch_func() @@ -250,7 +263,7 @@ def _validate_config(dataset_cfg: DictConfig) -> None: ) elif dataset_cfg.get('remote') is not None: # Using the streaming dataset codepath - illegal_keys = ['hf_name', 'hf_kwargs', 'preprocessing_fn'] + illegal_keys = ['hf_name', 'hf_kwargs', 'preprocessing_fn', 'safe_load'] discovered_illegal_keys = [] for key in illegal_keys: if dataset_cfg.get(key) is not None: @@ -275,11 +288,8 @@ def _validate_config(dataset_cfg: DictConfig) -> None: ) -def _build_hf_dataset_from_remote( - cfg: DictConfig, tokenizer: PreTrainedTokenizerBase -) -> Union[hf_datasets.DatasetDict, hf_datasets.Dataset, - hf_datasets.IterableDatasetDict, hf_datasets.IterableDataset]: - """Builds a dataset from a remote object store. +def _download_remote_hf_dataset(remote_path: str, split: str) -> str: + """Downloads a dataset from a remote object store. This function supports 'jsonl', 'csv', and 'parquet' file formats for the dataset. It will attempt to download the dataset, then once it is downloaded, convert it into HuggingFace ``datasets`` format, and then return this @@ -290,38 +300,26 @@ def _build_hf_dataset_from_remote( completed, the function removes the signal file. Args: - cfg (DictConfig): The configuration dictionary containing the necessary parameters to load the dataset. - This includes: - - dataset.hf_name: The path of the HuggingFace dataset to download. - - dataset.split: The dataset split to download (e.g., 'train', 'validation', 'test'). - - dataset.max_seq_len: The maximum sequence length for tokenizing the dataset. - - tokenizer (Tokenizer): The tokenizer to be used to tokenize the dataset. + hf_name (str): The path of the HuggingFace dataset to download. + split (str): The dataset split to download (e.g., 'train', 'validation', 'test'). Returns: - Dataset: A HuggingFace dataset built from the remote file, prepared and tokenized for fine-tuning the model. + A local directory path where the dataset files are stored. Raises: FileNotFoundError: Raised if the dataset file cannot be found with any of the supported extensions. """ - supported_extensions = ['jsonl', 'csv', 'parquet'] - # HF datasets does not support a split with dashes, so we replace dashes - # with underscores in the destination split. - destination_split = cfg.dataset.split.replace('-', '_') finetune_dir = os.path.join( - os.path.dirname( - os.path.dirname(os.path.dirname(os.path.realpath(__file__)))), - 'downloaded_finetuning', - destination_split if destination_split != 'data' else 'data_not', + DOWNLOADED_FT_DATASETS_DIRPATH, + split if split != 'data' else 'data_not', ) os.makedirs(finetune_dir, exist_ok=True) - for extension in supported_extensions: - name = f'{cfg.dataset.hf_name.strip("/")}/{cfg.dataset.split}.{extension}' + for extension in SUPPORTED_EXTENSIONS: + name = f'{remote_path.strip("/")}/{split}{extension}' destination = str( os.path.abspath( - os.path.join( - finetune_dir, 'data', - f'{destination_split}-00000-of-00001.{extension}'))) + os.path.join(finetune_dir, 'data', + f'{split}-00000-of-00001{extension}'))) # Since we don't know exactly what the extension will be, since it is one of a list # use a signal file to wait for instead of the desired file @@ -331,14 +329,14 @@ def _build_hf_dataset_from_remote( try: get_file(path=name, destination=destination, overwrite=True) except FileNotFoundError as e: - if extension == supported_extensions[-1]: + if extension == SUPPORTED_EXTENSIONS[-1]: files_searched = [ - f'{cfg.dataset.hf_name}/{cfg.dataset.split}.{ext}' - for ext in supported_extensions + f'{cfg.dataset.hf_name}/{cfg.dataset.split}{ext}' + for ext in SUPPORTED_EXTENSIONS ] raise FileNotFoundError( f'Could not find a file with any of ' + \ - f'the supported extensions: {supported_extensions}\n' + \ + f'the supported extensions: {SUPPORTED_EXTENSIONS}\n' + \ f'at {files_searched}' ) from e else: @@ -350,25 +348,18 @@ def _build_hf_dataset_from_remote( with open(signal_file_path, 'wb') as f: f.write(b'local_rank0_completed_download') - # Avoid the collective call until the local rank zero has finished trying to download the checkpoint + # Avoid the collective call until the local rank zero has finished trying to download the dataset # so that we don't timeout for large downloads. This syncs all processes on the node with dist.local_rank_zero_download_and_wait(signal_file_path): - # Then, wait to ensure every node has finished downloading the checkpoint + # Then, wait to ensure every node has finished trying to download the dataset dist.barrier() # clean up signal file if dist.get_local_rank() == 0: os.remove(signal_file_path) dist.barrier() - - cfg.dataset.hf_name = finetune_dir - log.info(cfg.dataset) - dataset = dataset_constructor.build_from_hf( - cfg.dataset, - max_seq_len=cfg.dataset.max_seq_len, - tokenizer=tokenizer, - ) - return dataset + break + return finetune_dir def _build_collate_fn( diff --git a/llmfoundry/data/finetuning/tasks.py b/llmfoundry/data/finetuning/tasks.py index 21c3558b2d..e61d138c41 100644 --- a/llmfoundry/data/finetuning/tasks.py +++ b/llmfoundry/data/finetuning/tasks.py @@ -35,11 +35,12 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]: import logging import os import warnings +from pathlib import Path from typing import Any, Callable, Dict, List, Optional, Union import datasets as hf_datasets +import huggingface_hub as hf_hub from composer.utils import dist -from omegaconf import DictConfig from streaming import StreamingDataset from transformers import PreTrainedTokenizerBase @@ -51,6 +52,22 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]: _ALLOWED_RESPONSE_KEYS = {'response', 'completion'} _ALLOWED_PROMPT_KEYS = {'prompt'} +DOWNLOADED_FT_DATASETS_DIRPATH = os.path.abspath( + os.path.join(os.path.realpath(__file__), os.pardir, os.pardir, os.pardir, + '.downloaded_finetuning')) +SUPPORTED_EXTENSIONS = ['.csv', '.jsonl', '.parquet'] + + +def _is_empty_or_nonexistent(dirpath: str) -> bool: + """Check if a directory is empty or non-existent. + + Args: + dirpath (str): Directory path to check. + + Returns + True if directory is empty or non-existent. False otherwise. + """ + return not os.path.isdir(dirpath) or len(os.listdir(dirpath)) == 0 def _tokenize_formatted_example( @@ -241,8 +258,9 @@ def print_registered_tasks(self) -> None: log.info('\n'.join(tasks)) def get_preprocessing_fn_from_dict( - self, mapping: Union[Dict, DictConfig] - ) -> Callable[[Dict[str, Any]], Dict[str, str]]: + self, + mapping: Dict[str, + str]) -> Callable[[Dict[str, Any]], Dict[str, str]]: """Get a preprocessing function from a dictionary. The dictionary maps column names in the dataset to "prompt" and "response". @@ -327,8 +345,10 @@ def get_preprocessing_fn_from_str( return preprocessing_fn def build_from_hf( - self, cfg: DictConfig, max_seq_len: int, - tokenizer: PreTrainedTokenizerBase + self, dataset_name: str, split: Optional[str], safe_load: bool, + max_seq_len: int, preprocessing_fn: Optional[Callable[[dict[str, Any]], + dict[str, str]]], + tokenizer: PreTrainedTokenizerBase, hf_kwargs: Dict[str, Any] ) -> Union[hf_datasets.DatasetDict, hf_datasets.Dataset, hf_datasets.IterableDatasetDict, hf_datasets.IterableDataset]: """Load a HuggingFace Datasets, preprocess, and tokenize. @@ -343,20 +363,6 @@ def build_from_hf( Returns: Dataset: The tokenized dataset. """ - dataset_name = cfg.hf_name - # HF datasets does not support a split with dashes,so we replace split - # dashes with underscore. - split = cfg.split.replace('-', '_') - kwargs = cfg.get('hf_kwargs', {}) - proto_preprocessing_fn = cfg.get('preprocessing_fn') - if isinstance(proto_preprocessing_fn, dict) or isinstance( - proto_preprocessing_fn, DictConfig): - preprocessing_fn = self.get_preprocessing_fn_from_dict( - proto_preprocessing_fn) - else: - preprocessing_fn = self.get_preprocessing_fn_from_str( - proto_preprocessing_fn, dataset_name) - signal_file_path = f'.node_{dist.get_node_rank()}_local_rank0_data_prep_completed' # Non local rank 0 ranks will wait here for local rank 0 to finish the data processing. @@ -379,9 +385,47 @@ def build_from_hf( error: Optional[Exception] = None filtered_dataset = None try: + if safe_load: + if not os.path.isdir(dataset_name): + # dataset_name is not a local dir path, download if needed. + local_dataset_dir = os.path.join( + DOWNLOADED_FT_DATASETS_DIRPATH, dataset_name) + + if _is_empty_or_nonexistent(dirpath=local_dataset_dir): + # Safely load a dataset from HF Hub with restricted file types. + hf_hub.snapshot_download( + dataset_name, + repo_type='dataset', + allow_patterns=[ + '*' + ext for ext in SUPPORTED_EXTENSIONS + ], + token=hf_kwargs.get('token', None), + revision=hf_kwargs.get('revision', None), + local_dir_use_symlinks=False, + local_dir=local_dataset_dir) + if _is_empty_or_nonexistent(dirpath=local_dataset_dir): + raise FileNotFoundError( + f'safe_load is set to True. No data files with safe extensions {SUPPORTED_EXTENSIONS} ' + + f'found for dataset {dataset_name}. ') + # Set dataset_name to the downloaded location. + dataset_name = local_dataset_dir + + # dataset_name is a local dir path. Use the abspath to prevent confusion. + dataset_name = os.path.abspath(dataset_name) + + # Ensure that the local dir contains only allowed file types. + dataset_files = [ + f for _, _, files in os.walk(dataset_name) for f in files + ] + if not all( + Path(f).suffix in SUPPORTED_EXTENSIONS + for f in dataset_files): + raise ValueError( + f'Dataset at local path {dataset_name} contains invalid file types. ' + + f'Allowed file types are: {SUPPORTED_EXTENSIONS}') dataset = hf_datasets.load_dataset(dataset_name, split=split, - **kwargs) + **hf_kwargs) def dataset_mapper(example: Dict): if preprocessing_fn is not None: diff --git a/llmfoundry/models/inference_api_wrapper/__init__.py b/llmfoundry/models/inference_api_wrapper/__init__.py index 496abf2aa6..9bb2ece2b2 100644 --- a/llmfoundry/models/inference_api_wrapper/__init__.py +++ b/llmfoundry/models/inference_api_wrapper/__init__.py @@ -1,6 +1,8 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 +from llmfoundry.models.inference_api_wrapper.fmapi import ( + FMAPICasualLMEvalWrapper, FMAPIChatAPIEvalWrapper) from llmfoundry.models.inference_api_wrapper.interface import \ InferenceAPIEvalWrapper from llmfoundry.models.inference_api_wrapper.openai_causal_lm import ( @@ -10,4 +12,6 @@ 'OpenAICausalLMEvalWrapper', 'OpenAIChatAPIEvalWrapper', 'InferenceAPIEvalWrapper', + 'FMAPICasualLMEvalWrapper', + 'FMAPIChatAPIEvalWrapper', ] diff --git a/llmfoundry/models/inference_api_wrapper/fmapi.py b/llmfoundry/models/inference_api_wrapper/fmapi.py new file mode 100644 index 0000000000..867b3c272e --- /dev/null +++ b/llmfoundry/models/inference_api_wrapper/fmapi.py @@ -0,0 +1,72 @@ +# Copyright 2022 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +import logging +import os +import time +from typing import Dict + +import requests +from transformers import AutoTokenizer + +from llmfoundry.models.inference_api_wrapper.openai_causal_lm import ( + OpenAICausalLMEvalWrapper, OpenAIChatAPIEvalWrapper, OpenAIEvalInterface) + +__all__ = [ + 'FMAPICasualLMEvalWrapper', + 'FMAPIChatAPIEvalWrapper', +] + +log = logging.getLogger(__name__) + + +def block_until_ready(base_url: str): + """Block until the endpoint is ready.""" + sleep_s = 5 + timout_s = 5 * 60 # At max, wait 5 minutes + + ping_url = f'{base_url}/ping' + + waited_s = 0 + while True: + try: + requests.get(ping_url) + log.info(f'Endpoint {ping_url} is ready') + break + except requests.exceptions.ConnectionError: + log.debug( + f'Endpoint {ping_url} not ready yet. Sleeping {sleep_s} seconds' + ) + time.sleep(sleep_s) + waited_s += sleep_s + + if waited_s >= timout_s: + raise TimeoutError( + f'Endpoint {ping_url} did not become read after {waited_s:,} seconds, exiting' + ) + + +class FMAPIEvalInterface(OpenAIEvalInterface): + + def __init__(self, model_cfg: Dict, tokenizer: AutoTokenizer): + is_local = model_cfg.pop('local', False) + if is_local: + base_url = os.environ.get('MOSAICML_MODEL_ENDPOINT', + 'http://0.0.0.0:8080/v2') + model_cfg['base_url'] = base_url + block_until_ready(base_url) + + if 'base_url' not in model_cfg: + raise ValueError( + 'Must specify base_url or use local=True in model_cfg for FMAPIsEvalWrapper' + ) + + super().__init__(model_cfg, tokenizer) + + +class FMAPICasualLMEvalWrapper(FMAPIEvalInterface, OpenAICausalLMEvalWrapper): + """Databricks Foundational Model API wrapper for causal LM models.""" + + +class FMAPIChatAPIEvalWrapper(FMAPIEvalInterface, OpenAIChatAPIEvalWrapper): + """Databricks Foundational Model API wrapper for chat models.""" diff --git a/llmfoundry/models/inference_api_wrapper/openai_causal_lm.py b/llmfoundry/models/inference_api_wrapper/openai_causal_lm.py index 39de2ba59c..587dd179bd 100644 --- a/llmfoundry/models/inference_api_wrapper/openai_causal_lm.py +++ b/llmfoundry/models/inference_api_wrapper/openai_causal_lm.py @@ -36,9 +36,6 @@ class OpenAIEvalInterface(InferenceAPIEvalWrapper): def __init__(self, model_cfg: Dict, tokenizer: AutoTokenizer) -> None: super().__init__(model_cfg, tokenizer) - assert os.getenv( - 'OPENAI_API_KEY' - ) is not None, 'No OpenAI API Key found. Ensure it is saved as an environmental variable called OPENAI_API_KEY.' try: import openai except ImportError as e: @@ -46,8 +43,28 @@ def __init__(self, model_cfg: Dict, tokenizer: AutoTokenizer) -> None: extra_deps_group='openai', conda_package='openai', conda_channel='conda-forge') from e - self.client = openai.OpenAI() - self.model_name = model_cfg['version'] + + api_key = os.environ.get('OPENAI_API_KEY') + base_url = model_cfg.get('base_url') + if base_url is None: + # Using OpenAI default, where the API key is required + if api_key is None: + raise ValueError( + 'No OpenAI API Key found. Ensure it is saved as an environmental variable called OPENAI_API_KEY.' + ) + + else: + # Using a custom base URL, where the API key may not be required + log.info( + f'Making request to custom base URL: {base_url}{"" if api_key is not None else " (no API key set)"}' + ) + api_key = 'placeholder' # This cannot be None + + self.client = openai.OpenAI(base_url=base_url, api_key=api_key) + if 'version' in model_cfg: + self.model_name = model_cfg['version'] + else: + self.model_name = model_cfg['name'] def generate_completion(self, prompt: str, num_tokens: int): raise NotImplementedError() diff --git a/llmfoundry/models/layers/__init__.py b/llmfoundry/models/layers/__init__.py index 68aa0fe7fe..05350b059b 100644 --- a/llmfoundry/models/layers/__init__.py +++ b/llmfoundry/models/layers/__init__.py @@ -2,9 +2,9 @@ # SPDX-License-Identifier: Apache-2.0 from llmfoundry.models.layers.attention import ( - ATTN_CLASS_REGISTRY, MultiheadAttention, MultiQueryAttention, - attn_bias_shape, build_alibi_bias, build_attn_bias, flash_attn_fn, - scaled_multihead_dot_product_attention, triton_flash_attn_fn) + ATTN_CLASS_REGISTRY, GroupedQueryAttention, MultiheadAttention, + MultiQueryAttention, attn_bias_shape, build_alibi_bias, build_attn_bias, + flash_attn_fn, scaled_multihead_dot_product_attention, triton_flash_attn_fn) from llmfoundry.models.layers.blocks import MPTBlock from llmfoundry.models.layers.custom_embedding import SharedEmbedding from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY @@ -17,6 +17,7 @@ 'triton_flash_attn_fn', 'MultiheadAttention', 'MultiQueryAttention', + 'GroupedQueryAttention', 'attn_bias_shape', 'build_attn_bias', 'build_alibi_bias', diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py index 2a90bf2f80..fecd79553f 100644 --- a/llmfoundry/models/layers/attention.py +++ b/llmfoundry/models/layers/attention.py @@ -39,6 +39,11 @@ def is_transformers_version_gte(hf_version: str) -> bool: return version.parse(transformers.__version__) >= version.parse(hf_version) +def check_alibi_support(attention_impl: str) -> bool: + return attention_impl != 'flash' or is_flash_v2_installed( + v2_version='v2.4.2') + + # Before importing any transformers models, we need to disable transformers flash attention if # we are in an environment with flash attention version <2. Transformers hard errors on a not properly # gated import otherwise. @@ -223,11 +228,17 @@ def flash_attn_fn( training: bool = False, needs_weights: bool = False, multiquery: bool = False, - attention_mask_in_length: Optional[torch.Tensor] = None, should_repeat_kv_for_gqa: Optional[bool] = True, sliding_window_size: int = -1, + alibi_slopes: Optional[torch.Tensor] = None, + flash_attn_padding_info: Optional[dict[str, torch.Tensor]] = None, ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor, torch.Tensor]]]: + if key_padding_mask is not None: + raise ValueError('key_padding_mask should be None for flash attn.') + del key_padding_mask + if flash_attn_padding_info is None: + raise ValueError('flash_attn_padding_info is required for flash attn.') try: from flash_attn import bert_padding, flash_attn_interface # type: ignore # yapf: disable # isort: skip except: @@ -261,25 +272,24 @@ def flash_attn_fn( batch_size, seqlen = query.shape[:2] - if attention_mask_in_length is None: - if key_padding_mask is None: - key_padding_mask = torch.ones_like(key[:, :, 0], dtype=torch.bool) - query_padding_mask = key_padding_mask[:, -query.size(1):] - unpadding_function = bert_padding.unpad_input - else: - key_padding_mask = attention_mask_in_length - query_padding_mask = attention_mask_in_length - unpadding_function = bert_padding.unpad_input_for_concatenated_sequences + indices_q = flash_attn_padding_info['indices_q'] + indices_k = flash_attn_padding_info['indices_k'] + indices_v = flash_attn_padding_info['indices_v'] + cu_seqlens_q = flash_attn_padding_info['cu_seqlens_q'] + cu_seqlens_k = flash_attn_padding_info['cu_seqlens_k'] + max_seqlen_q = flash_attn_padding_info['max_seqlen_q'] + max_seqlen_k = flash_attn_padding_info['max_seqlen_k'] - query_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpadding_function( - query, query_padding_mask) + query_unpad = bert_padding.index_first_axis( + rearrange(query, 'b s ... -> (b s) ...'), indices_q) query_unpad = rearrange(query_unpad, 'nnz (h d) -> nnz h d', h=n_heads) - key_unpad, _, cu_seqlens_k, max_seqlen_k = unpadding_function( - key, key_padding_mask) + key_unpad = bert_padding.index_first_axis( + rearrange(key, 'b s ... -> (b s) ...'), indices_k) key_unpad = rearrange(key_unpad, 'nnz (h d) -> nnz h d', h=kv_n_heads) - value_unpad, _, _, _ = unpadding_function(value, key_padding_mask) + value_unpad = bert_padding.index_first_axis( + rearrange(value, 'b s ... -> (b s) ...'), indices_v) value_unpad = rearrange(value_unpad, 'nnz (h d) -> nnz h d', h=kv_n_heads) if (kv_n_heads < n_heads) and (not is_flash_v2_installed()) and ( @@ -334,6 +344,12 @@ def flash_attn_fn( causal=reset_is_causal, return_attn_probs=needs_weights) elif is_flash_v2_installed(): + alibi_kwargs = {} + if check_alibi_support('flash'): + alibi_kwargs = {'alibi_slopes': alibi_slopes} + elif alibi_slopes is not None: + raise ValueError( + 'alibi_slopes is only supported for flash-attn>=2.4.2') output_unpad = flash_attn_interface.flash_attn_varlen_func( q=query_unpad, k=key_unpad, @@ -346,10 +362,11 @@ def flash_attn_fn( softmax_scale=softmax_scale, causal=reset_is_causal, return_attn_probs=needs_weights, - window_size=(sliding_window_size, sliding_window_size)) + window_size=(sliding_window_size, sliding_window_size), + **alibi_kwargs) else: raise RuntimeError( - 'flash-attn==1.0.9 or flash-attn==2.3.6 is required.') + 'flash-attn==1.0.9 or flash-attn==2.4.2 is required.') output = bert_padding.pad_input( rearrange(output_unpad, 'nnz h d -> nnz (h d)'), indices_q, batch_size, @@ -586,7 +603,8 @@ def forward( rotary_emb_w_meta_info: Optional[dict] = None, is_causal: bool = True, needs_weights: bool = False, - attention_mask_in_length: Optional[torch.Tensor] = None, + alibi_slopes: Optional[torch.Tensor] = None, + flash_attn_padding_info: Optional[dict[str, torch.Tensor]] = None, ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[ torch.Tensor, torch.Tensor]]]: qkv = self.Wqkv(x) @@ -652,10 +670,12 @@ def forward( extra_attn_kwargs = {} if self.attn_impl == 'flash': + key_padding_mask = None extra_attn_kwargs = { - 'attention_mask_in_length': attention_mask_in_length, 'should_repeat_kv_for_gqa': not is_flash_v2_installed(), 'sliding_window_size': self.sliding_window_size, + 'alibi_slopes': alibi_slopes, + 'flash_attn_padding_info': flash_attn_padding_info, } context, attn_weights, past_key_value = self.attn_fn( @@ -805,7 +825,8 @@ def build_attn_bias( def gen_slopes(n_heads: int, alibi_bias_max: int = 8, - device: Optional[torch.device] = None) -> torch.Tensor: + device: Optional[torch.device] = None, + return_1d: bool = False) -> torch.Tensor: _n_heads = 2**math.ceil(math.log2(n_heads)) m = torch.arange(1, _n_heads + 1, dtype=torch.float32, device=device) m = m.mul(alibi_bias_max / _n_heads) @@ -816,7 +837,8 @@ def gen_slopes(n_heads: int, # Huggingface and FasterTransformer calculate slopes normally, # then return this strided concatenation of slopes slopes = torch.concat([slopes[1::2], slopes[::2]])[:n_heads] - + if return_1d: + return slopes return slopes.view(1, n_heads, 1, 1) diff --git a/llmfoundry/models/layers/blocks.py b/llmfoundry/models/layers/blocks.py index c077ccb535..036a4e7cd2 100644 --- a/llmfoundry/models/layers/blocks.py +++ b/llmfoundry/models/layers/blocks.py @@ -122,7 +122,8 @@ def forward( attention_mask: Optional[torch.ByteTensor] = None, is_causal: bool = True, output_attentions: bool = False, - attention_mask_in_length: Optional[torch.Tensor] = None, + alibi_slopes: Optional[torch.Tensor] = None, + flash_attn_padding_info: Optional[dict[str, torch.Tensor]] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[ torch.Tensor, torch.Tensor]]]: a = self.norm_1(x) @@ -134,7 +135,8 @@ def forward( attention_mask=attention_mask, is_causal=is_causal, needs_weights=output_attentions, - attention_mask_in_length=attention_mask_in_length, + alibi_slopes=alibi_slopes, + flash_attn_padding_info=flash_attn_padding_info, ) x = x + self.resid_attn_dropout(b) m = x diff --git a/llmfoundry/models/model_registry.py b/llmfoundry/models/model_registry.py index be09a69835..ff9942f5f6 100644 --- a/llmfoundry/models/model_registry.py +++ b/llmfoundry/models/model_registry.py @@ -3,7 +3,9 @@ from llmfoundry.models.hf import (ComposerHFCausalLM, ComposerHFPrefixLM, ComposerHFT5) -from llmfoundry.models.inference_api_wrapper import (OpenAICausalLMEvalWrapper, +from llmfoundry.models.inference_api_wrapper import (FMAPICasualLMEvalWrapper, + FMAPIChatAPIEvalWrapper, + OpenAICausalLMEvalWrapper, OpenAIChatAPIEvalWrapper) from llmfoundry.models.mpt import ComposerMPTCausalLM @@ -13,5 +15,7 @@ 'hf_prefix_lm': ComposerHFPrefixLM, 'hf_t5': ComposerHFT5, 'openai_causal_lm': OpenAICausalLMEvalWrapper, - 'openai_chat': OpenAIChatAPIEvalWrapper + 'fmapi_causal_lm': FMAPICasualLMEvalWrapper, + 'openai_chat': OpenAIChatAPIEvalWrapper, + 'fmapi_chat': FMAPIChatAPIEvalWrapper, } diff --git a/llmfoundry/models/mpt/configuration_mpt.py b/llmfoundry/models/mpt/configuration_mpt.py index ae4754108c..5474529277 100644 --- a/llmfoundry/models/mpt/configuration_mpt.py +++ b/llmfoundry/models/mpt/configuration_mpt.py @@ -8,7 +8,8 @@ from transformers import PretrainedConfig -from llmfoundry.models.layers.attention import is_flash_v2_installed +from llmfoundry.models.layers.attention import (check_alibi_support, + is_flash_v2_installed) from llmfoundry.models.layers.blocks import attn_config_defaults # NOTE: All utils are imported directly even if unused so that @@ -220,11 +221,11 @@ def _validate_config(self) -> None: 'attn_impl'] not in ['torch', 'triton']: raise NotImplementedError( 'prefix_lm only implemented with torch and triton attention.') - if self.attn_config['alibi'] and self.attn_config['attn_impl'] not in [ - 'torch', 'triton' - ]: + if self.attn_config['alibi'] and not check_alibi_support( + self.attn_config['attn_impl']): raise NotImplementedError( - 'alibi only implemented with torch and triton attention.') + 'alibi only implemented with torch, triton, and flash (v2.4.2 or higher) attention.' + ) if self.attn_config['attn_uses_sequence_id'] and not ( self.attn_config['attn_impl'] in ['torch', 'triton'] or (self.attn_config['attn_impl'] == 'flash' and diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 4c80b10ed9..2177124740 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -6,6 +6,8 @@ Inspired by https://github.com/karpathy/minGPT/blob/master/mingpt/model.py """ +from __future__ import annotations + import math import warnings from typing import (Any, Dict, List, Mapping, MutableMapping, Optional, Tuple, @@ -24,15 +26,23 @@ from composer.models import HuggingFaceModel from composer.utils import dist -from llmfoundry.models.layers.attention import is_flash_v2_installed +from llmfoundry.models.layers.attention import (is_flash_v1_installed, + is_flash_v2_installed) if is_flash_v2_installed(): try: # This try...except is needed because transformers requires it despite the 'if' statement above + from flash_attn import bert_padding from flash_attn.layers.rotary import \ RotaryEmbedding as DAILRotaryEmbedding except Exception as e: raise e +if is_flash_v1_installed(): + try: # This try...except is needed because transformers requires it despite the 'if' statement above + from flash_attn import bert_padding + except Exception as e: + raise e + from omegaconf import DictConfig from omegaconf import OmegaConf as om from transformers import PreTrainedModel, PreTrainedTokenizerBase @@ -47,7 +57,7 @@ from llmfoundry.models.layers.attention import (ATTN_CLASS_REGISTRY, attn_bias_shape, - build_attn_bias) + build_attn_bias, gen_slopes) from llmfoundry.models.layers.blocks import MPTBlock from llmfoundry.models.layers.custom_embedding import SharedEmbedding from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY as FC_CLASS_REGISTRY @@ -216,6 +226,44 @@ def gen_attention_mask_in_length(sequence_id: Union[None, torch.Tensor], S: int, return attention_mask_in_length +def gen_flash_attn_padding_info( + bsz: int, + S: int, + past_key_len: int, + device: torch.device, + attention_mask_in_length: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None): + flash_attn_padding_info = {} + if attention_mask_in_length is None: + key_padding_mask = attention_mask + if key_padding_mask is None: + key_padding_mask = torch.ones((bsz, past_key_len + S), + dtype=torch.bool, + device=device) + query_padding_mask = key_padding_mask[:, -S:] + unpadding_function = bert_padding.unpad_input + else: + key_padding_mask = attention_mask_in_length + query_padding_mask = attention_mask_in_length + unpadding_function = bert_padding.unpad_input_for_concatenated_sequences + + _, indices_q, cu_seqlens_q, max_seqlen_q = unpadding_function( + torch.empty(bsz, S, 1, device=device), query_padding_mask) + _, indices_k, cu_seqlens_k, max_seqlen_k = unpadding_function( + torch.empty(bsz, past_key_len + S, 1, device=device), key_padding_mask) + _, indices_v, _, _ = unpadding_function( + torch.empty(bsz, past_key_len + S, 1, device=device), key_padding_mask) + + flash_attn_padding_info['indices_q'] = indices_q + flash_attn_padding_info['indices_k'] = indices_k + flash_attn_padding_info['indices_v'] = indices_v + flash_attn_padding_info['cu_seqlens_q'] = cu_seqlens_q + flash_attn_padding_info['cu_seqlens_k'] = cu_seqlens_k + flash_attn_padding_info['max_seqlen_q'] = max_seqlen_q + flash_attn_padding_info['max_seqlen_k'] = max_seqlen_k + return flash_attn_padding_info + + def apply_sequence_id(attn_bias: torch.Tensor, sequence_id: torch.LongTensor, max_seq_len: int) -> torch.Tensor: seq_len = sequence_id.shape[-1] @@ -246,6 +294,14 @@ class MPTPreTrainedModel(PreTrainedModel): _no_split_modules = ['MPTBlock'] +def _fsdp_wrap_fn( + self: Union[MPTModel, MPTForCausalLM], + module: nn.Module, +) -> bool: + # FSDP Wrap function for MPT Models + return isinstance(module, MPTBlock) + + class MPTModel(MPTPreTrainedModel): def __init__(self, config: MPTConfig): @@ -330,12 +386,12 @@ def __init__(self, config: MPTConfig): for module in self.modules(): if hasattr(module, 'bias') and isinstance( module.bias, nn.Parameter): - log.info(f'Removing bias ({module.bias}) from {module}.') + log.info(f'Removing bias from {module=}.') module.register_parameter('bias', None) # For transformer engine if hasattr(module, 'use_bias'): - log.info(f'Setting use_bias=False for {module}.') + log.info(f'Setting use_bias=False for {module=}.') module.use_bias = False log.debug(self) @@ -515,10 +571,12 @@ def forward( raise ValueError( 'You cannot specify both input_ids and inputs_embeds.') elif input_ids is not None: + bsz = input_ids.size(0) S = input_ids.size(1) x = self.wte(input_ids) input_device = input_ids.device elif inputs_embeds is not None: + bsz = inputs_embeds.size(0) S = inputs_embeds.size(1) x = inputs_embeds input_device = inputs_embeds.device @@ -530,22 +588,23 @@ def forward( ), f'Cannot forward input with seq_len={S}, this model only supports seq_len<={self.config.max_seq_len}' rotary_emb_w_meta_info = None - if self.learned_pos_emb or self.rope: - past_position = 0 - if past_key_values is not None: - if len(past_key_values) != self.config.n_layers: - raise ValueError( - f'past_key_values must provide a past_key_value for each attention ' - + - f'layer in the network ({len(past_key_values)=}; {self.config.n_layers=}).' - ) - # For attn_impl: triton and flash the past key tensor spec is (batch, seq, dim). - # For attn_impl: torch the past key tensor spec is (batch, heads, head_dim, seq). - # Here we shift position embedding using the `seq` dim of the past key - past_position = past_key_values[0][0].size(1) - if self.attn_impl == 'torch': - past_position = past_key_values[0][0].size(3) + past_position = 0 + if past_key_values is not None: + if len(past_key_values) != self.config.n_layers: + raise ValueError( + f'past_key_values must provide a past_key_value for each attention ' + + + f'layer in the network ({len(past_key_values)=}; {self.config.n_layers=}).' + ) + # For attn_impl: triton and flash the past key tensor spec is (batch, seq, dim). + # For attn_impl: torch the past key tensor spec is (batch, heads, head_dim, seq). + # Here we shift position embedding using the `seq` dim of the past key + past_position = past_key_values[0][0].size(1) + if self.attn_impl == 'torch': + past_position = past_key_values[0][0].size(3) + + if self.learned_pos_emb or self.rope: if self.learned_pos_emb and (S + past_position > self.config.max_seq_len): raise ValueError( @@ -607,6 +666,14 @@ def forward( attn_uses_sequence_id=self.attn_uses_sequence_id, attn_impl=self.attn_impl, attention_mask=attention_mask) + + alibi_slopes = None # alibi_slopes will only be used by flash attention for ALiBi + if self.alibi and self.attn_impl == 'flash': + alibi_slopes = gen_slopes(n_heads=self.config.n_heads, + alibi_bias_max=self.alibi_bias_max, + device=x.device, + return_1d=True) + # initialize the past key values cache if it should be used presents = () if use_cache else None if use_cache and past_key_values is None: @@ -615,6 +682,12 @@ def forward( all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None + flash_attn_padding_info = {} + if self.attn_impl == 'flash': + flash_attn_padding_info = gen_flash_attn_padding_info( + bsz, S, past_position, x.device, attention_mask_in_length, + attention_mask) + for b_idx, block in enumerate(self.blocks): if output_hidden_states: assert all_hidden_states is not None # pyright @@ -629,7 +702,8 @@ def forward( attention_mask=attention_mask, is_causal=self.is_causal, output_attentions=bool(output_attentions), - attention_mask_in_length=attention_mask_in_length, + alibi_slopes=alibi_slopes, + flash_attn_padding_info=flash_attn_padding_info, ) if presents is not None: presents += (present,) @@ -664,7 +738,7 @@ def param_init_fn(self, module: nn.Module) -> None: # FSDP Wrap function def fsdp_wrap_fn(self, module: nn.Module) -> bool: - return isinstance(module, MPTBlock) + return _fsdp_wrap_fn(self, module) # Activation Checkpointing def activation_checkpointing_fn(self, module: nn.Module) -> bool: @@ -825,7 +899,7 @@ def param_init_fn(self, module: nn.Module) -> None: # FSDP Wrap function def fsdp_wrap_fn(self, module: nn.Module) -> bool: - return isinstance(module, MPTBlock) + return _fsdp_wrap_fn(self, module) # Activation Checkpointing def activation_checkpointing_fn(self, module: nn.Module) -> bool: diff --git a/llmfoundry/optim/scheduler.py b/llmfoundry/optim/scheduler.py index 4a6d21c873..5598db28a1 100644 --- a/llmfoundry/optim/scheduler.py +++ b/llmfoundry/optim/scheduler.py @@ -16,24 +16,19 @@ def _raise_if_units_dont_match(time: Union[str, Time], t_max: Union[str, Time], name: str) -> None: - if isinstance(time, str): - time = Time.from_timestring(time) - if isinstance(t_max, str): - t_max = Time.from_timestring(t_max) + new_time = Time.from_input(time) + new_t_max = Time.from_input(t_max) - assert not isinstance(time, str) and not isinstance(t_max, str) - - if time.unit != t_max.unit: - raise ValueError(f'{time.unit=} does not match {t_max.unit=}.') + if new_time.unit != new_t_max.unit: + raise ValueError( + f'{name} (unit {new_time.unit=}) must match max_duration unit ({new_t_max.unit=}).' + ) def _raise_if_units_dur(time: Union[str, Time], name: str) -> None: - if isinstance(time, str): - time = Time.from_timestring(time) - - assert not isinstance(time, str) + new_time = Time.from_input(time) - if time.unit == TimeUnit('dur'): + if new_time.unit == TimeUnit('dur'): raise ValueError(f'{name} cannot be in units of "dur".') diff --git a/llmfoundry/tokenizers/tiktoken.py b/llmfoundry/tokenizers/tiktoken.py index 342b5c2ecf..eaaf0da316 100644 --- a/llmfoundry/tokenizers/tiktoken.py +++ b/llmfoundry/tokenizers/tiktoken.py @@ -198,7 +198,7 @@ def default_chat_template(self): '{% else %}' "{{ '\n' + '<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' }}" '{% endif %}' - '{% if (add_generation_prompt == true) %}' + '{% if (add_generation_prompt == true and loop.last) %}' "{{ '\n' + '<|im_start|>' + 'assistant' + '\n' }}" "{% elif (message['role'] == 'assistant') %}" '{{ eos_token }}' @@ -253,7 +253,10 @@ def _convert_token_to_id(self, token: str) -> Optional[int]: def _convert_id_to_token(self, index: int) -> Optional[str]: """Converts an index (integer) in a token (str) using the vocab.""" - return self.decoder.get(index) + # For tokens in either the gap in ids in the tokenizer, or beyond the range of the tokenizer, + # we return empty string. This matches the behavior of Hugging Face fast tokenizers, + # but not slow tokenizers. + return self.decoder.get(index, '') def convert_tokens_to_string(self, tokens: List[str]) -> str: """Converts a sequence of tokens (string) in a single string.""" diff --git a/llmfoundry/utils/__init__.py b/llmfoundry/utils/__init__.py index 7abe4dcf75..83af2153a2 100644 --- a/llmfoundry/utils/__init__.py +++ b/llmfoundry/utils/__init__.py @@ -12,7 +12,7 @@ log_config, pop_config, update_batch_size_info) from llmfoundry.utils.model_download_utils import ( - download_from_cache_server, download_from_hf_hub) + download_from_hf_hub, download_from_http_fileserver) except ImportError as e: raise ImportError( 'Please make sure to pip install . to get requirements for llm-foundry.' @@ -28,7 +28,7 @@ 'build_tokenizer', 'calculate_batch_size_info', 'convert_and_save_ft_weights', - 'download_from_cache_server', + 'download_from_http_fileserver', 'download_from_hf_hub', 'get_hf_tokenizer_from_composer_state_dict', 'update_batch_size_info', diff --git a/llmfoundry/utils/builders.py b/llmfoundry/utils/builders.py index 57ed681a09..42f817b386 100644 --- a/llmfoundry/utils/builders.py +++ b/llmfoundry/utils/builders.py @@ -219,21 +219,16 @@ def build_callback( def build_logger(name: str, kwargs: Dict[str, Any]) -> LoggerDestination: - kwargs_dict = { - k: v if isinstance(v, str) else om.to_container(v, resolve=True) - for k, v in kwargs.items() - } - if name == 'wandb': - return WandBLogger(**kwargs_dict) + return WandBLogger(**kwargs) elif name == 'tensorboard': - return TensorboardLogger(**kwargs_dict) + return TensorboardLogger(**kwargs) elif name == 'in_memory_logger': - return InMemoryLogger(**kwargs_dict) + return InMemoryLogger(**kwargs) elif name == 'mlflow': - return MLFlowLogger(**kwargs_dict) + return MLFlowLogger(**kwargs) elif name == 'inmemory': - return InMemoryLogger(**kwargs_dict) + return InMemoryLogger(**kwargs) else: raise ValueError(f'Not sure how to build logger: {name}') @@ -243,8 +238,6 @@ def build_algorithm(name: str, kwargs: Dict[str, Any]) -> Algorithm: return algorithms.GradientClipping(**kwargs) elif name == 'alibi': return algorithms.Alibi(**kwargs) - elif name == 'fused_layernorm': - return algorithms.FusedLayerNorm(**kwargs) elif name == 'gated_linear_units': return algorithms.GatedLinearUnits(**kwargs) elif name == 'low_precision_layernorm': diff --git a/llmfoundry/utils/config_utils.py b/llmfoundry/utils/config_utils.py index 55576eaba0..29d78a0770 100644 --- a/llmfoundry/utils/config_utils.py +++ b/llmfoundry/utils/config_utils.py @@ -120,18 +120,10 @@ def process_init_device(model_cfg: DictConfig, fsdp_config: Optional[Dict]): # Set defaults for mixed initialization fsdp_config.setdefault('use_orig_params', False) fsdp_config.setdefault('load_monolith_rank0_only', True) - # Always set `sync_module_states` to True when using hybrid sharding - if fsdp_config is not None and \ - fsdp_config.get('sharding_strategy', 'FULL_SHARD') in ['HYBRID_SHARD', '_HYBRID_SHARD_ZERO2'] \ - and not fsdp_config.get('sync_module_states', False): - warnings.warn( - ('Setting `sync_module_states = True` for FSDP. This is required ' - 'when using hybrid sharding.')) - fsdp_config['sync_module_states'] = True - - # no mixed precision needed for weights when they're already 16 bits + + # No mixed precision needed for weights when they're already 16 bits master_dtype = model_cfg.get('master_weights_dtype') - small_dtypes = ('bf16', 'f16', 'float16', 'bfloat16', 'amp_fp16', + small_dtypes = ('bf16', 'fp16', 'float16', 'bfloat16', 'amp_fp16', 'amp_bf16') if fsdp_config and master_dtype in small_dtypes: reduce_dtype = None diff --git a/llmfoundry/utils/huggingface_hub_utils.py b/llmfoundry/utils/huggingface_hub_utils.py index a74ab1cc35..07a9c3900e 100644 --- a/llmfoundry/utils/huggingface_hub_utils.py +++ b/llmfoundry/utils/huggingface_hub_utils.py @@ -59,7 +59,7 @@ def process_file( folder_path: str, flatten_imports_prefix: Sequence[str], ) -> list[str]: - with open(file_path, 'r') as f: + with open(file_path, 'r', encoding='utf-8') as f: source = f.read() parent_module_name = None @@ -102,7 +102,7 @@ def process_file( if new_filename == '__init__.py': new_filename = file_path.split('/')[-2] + '.py' new_file_path = os.path.join(folder_path, new_filename) - with open(new_file_path, 'w') as f: + with open(new_file_path, 'w', encoding='utf-8') as f: assert new_tree is not None f.write(ast.unparse(new_tree)) diff --git a/llmfoundry/utils/model_download_utils.py b/llmfoundry/utils/model_download_utils.py index 2104455e0f..07c84a85c8 100644 --- a/llmfoundry/utils/model_download_utils.py +++ b/llmfoundry/utils/model_download_utils.py @@ -5,6 +5,8 @@ import copy import logging import os +import shutil +import subprocess import time import warnings from http import HTTPStatus @@ -14,6 +16,7 @@ import huggingface_hub as hf_hub import requests import tenacity +import yaml from bs4 import BeautifulSoup from requests.packages.urllib3.exceptions import InsecureRequestWarning from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME @@ -27,6 +30,15 @@ ] PYTORCH_WEIGHTS_PATTERN = 'pytorch_model*.bin*' SAFE_WEIGHTS_PATTERN = 'model*.safetensors*' +TOKENIZER_FILES = [ + 'special_tokens_map.json', + 'tokenizer.json', + 'tokenizer.model', + 'tokenizer_config.json', +] + +ORAS_PASSWD_PLACEHOLDER = '' +ORAS_CLI = 'oras' log = logging.getLogger(__name__) @@ -36,9 +48,10 @@ stop=tenacity.stop_after_attempt(3), wait=tenacity.wait_exponential(min=1, max=10)) def download_from_hf_hub( - repo_id: str, - save_dir: Optional[str] = None, + model: str, + save_dir: str, prefer_safetensors: bool = True, + tokenizer_only: bool = False, token: Optional[str] = None, ): """Downloads model files from a Hugging Face Hub model repo. @@ -48,10 +61,10 @@ def download_from_hf_hub( Args: repo_id (str): The Hugging Face Hub repo ID. - save_dir (str, optional): The path to the directory where the model files will be downloaded. If `None`, reads - from the `HUGGINGFACE_HUB_CACHE` environment variable or uses the default Hugging Face Hub cache directory. + save_dir (str, optional): The local path to the directory where the model files will be downloaded. prefer_safetensors (bool): Whether to prefer Safetensors weights over PyTorch weights if both are available. Defaults to True. + tokenizer_only (bool): If true, only download tokenizer files. token (str, optional): The HuggingFace API token. If not provided, the token will be read from the `HUGGING_FACE_HUB_TOKEN` environment variable. @@ -59,7 +72,7 @@ def download_from_hf_hub( RepositoryNotFoundError: If the model repo doesn't exist or the token is unauthorized. ValueError: If the model repo doesn't contain any supported model weights. """ - repo_files = set(hf_hub.list_repo_files(repo_id)) + repo_files = set(hf_hub.list_repo_files(model)) # Ignore TensorFlow, TensorFlow 2, and Flax weights as they are not supported by Composer. ignore_patterns = copy.deepcopy(DEFAULT_IGNORE_PATTERNS) @@ -86,18 +99,21 @@ def download_from_hf_hub( log.info('Only pytorch available. Ignoring weights preference.') else: raise ValueError( - f'No supported model weights found in repo {repo_id}.' + + f'No supported model weights found in repo {model}.' + ' Please make sure the repo contains either safetensors or pytorch weights.' ) + allow_patterns = TOKENIZER_FILES if tokenizer_only else None + download_start = time.time() - hf_hub.snapshot_download(repo_id, - cache_dir=save_dir, + hf_hub.snapshot_download(model, + local_dir=save_dir, ignore_patterns=ignore_patterns, + allow_patterns=allow_patterns, token=token) download_duration = time.time() - download_start log.info( - f'Downloaded model {repo_id} from Hugging Face Hub in {download_duration} seconds' + f'Downloaded model {model} from Hugging Face Hub in {download_duration} seconds' ) @@ -140,6 +156,7 @@ def _recursive_download( RuntimeError: If the remote server returns a status code other than 200 OK or 401 Unauthorized. """ url = urljoin(base_url, path) + print(url) response = session.get(url, verify=(not ignore_cert)) if response.status_code == HTTPStatus.UNAUTHORIZED: @@ -156,7 +173,7 @@ def _recursive_download( ) # Assume that the URL points to a file if it does not end with a slash. - if not path.endswith('/'): + if not url.endswith('/'): save_path = os.path.join(save_dir, path) parent_dir = os.path.dirname(save_path) if not os.path.exists(parent_dir): @@ -171,6 +188,7 @@ def _recursive_download( # If the URL is a directory, the response should be an HTML directory listing that we can parse for additional links # to download. child_links = _extract_links_from_html(response.content.decode()) + print(child_links) for child_link in child_links: _recursive_download(session, base_url, @@ -183,53 +201,101 @@ def _recursive_download( (PermissionError, ValueError)), stop=tenacity.stop_after_attempt(3), wait=tenacity.wait_exponential(min=1, max=10)) -def download_from_cache_server( - model_name: str, - cache_base_url: str, +def download_from_http_fileserver( + url: str, save_dir: str, - token: Optional[str] = None, ignore_cert: bool = False, ): - """Downloads Hugging Face models from a mirror file server. - - The file server is expected to store the files in the same structure as the Hugging Face cache - structure. See https://huggingface.co/docs/huggingface_hub/guides/manage-cache. + """Downloads files from a remote HTTP file server. Args: - model_name: The name of the model to download. This should be the same as the repository ID in the Hugging Face - Hub. - cache_base_url: The base URL of the cache file server. This function will attempt to download all of the blob - files from `//blobs/`, where `formatted_model_name` is equal to - `models/` with all slashes replaced with `--`. - save_dir: The directory to save the downloaded files to. - token: The Hugging Face API token. If not provided, the token will be read from the `HUGGING_FACE_HUB_TOKEN` - environment variable. - ignore_cert: Whether or not to ignore the validity of the SSL certificate of the remote server. Defaults to - False. + url (str): The base URL where the files are located. + save_dir (str): The directory to save downloaded files to. + ignore_cert (bool): Whether or not to ignore the validity of the SSL certificate of the remote server. + Defaults to False. WARNING: Setting this to true is *not* secure, as no certificate verification will be performed. """ - formatted_model_name = f'models/{model_name}'.replace('/', '--') with requests.Session() as session: - session.headers.update({'Authorization': f'Bearer {token}'}) - - download_start = time.time() - # Temporarily suppress noisy SSL certificate verification warnings if ignore_cert is set to True with warnings.catch_warnings(): if ignore_cert: warnings.simplefilter('ignore', category=InsecureRequestWarning) - # Only downloads the blobs in order to avoid downloading model files twice due to the - # symlnks in the Hugging Face cache structure: - _recursive_download( - session, - cache_base_url, - # Trailing slash to indicate directory - f'{formatted_model_name}/blobs/', - save_dir, - ignore_cert=ignore_cert, - ) - download_duration = time.time() - download_start - log.info( - f'Downloaded model {model_name} from cache server in {download_duration} seconds' + _recursive_download(session, + url, + '', + save_dir, + ignore_cert=ignore_cert) + + +def download_from_oras(model: str, + config_file: str, + credentials_dir: str, + save_dir: str, + tokenizer_only: bool = False, + concurrency: int = 10): + """Download from an OCI-compliant registry using oras. + + Args: + model (str): The name of the model to download. + config_file (str): Path to a YAML config file that maps model and tokenizer names to registry paths. + credentials_dir (str): Path to a directory containing credentials for the registry. It is expected to contain three + files: `username`, `password`, and `registry`, each of which contains the corresponding credential. + save_dir (str): Path to the directory where files will be downloaded. + tokenizer_only (bool): If true, only download the tokenzier files. + concurrency (int): The number of concurrent downloads to run. + """ + if shutil.which(ORAS_CLI) is None: + raise Exception( + f'oras cli command `{ORAS_CLI}` is not found. Please install oras: https://oras.land/docs/installation ' ) + + def _read_secrets_file(secret_file_path: str,): + try: + with open(secret_file_path, encoding='utf-8') as f: + return f.read().strip() + except Exception as error: + raise ValueError( + f'secrets file {secret_file_path} failed to be read') from error + + secrets = {} + for secret in ['username', 'password', 'registry']: + secrets[secret] = _read_secrets_file( + os.path.join(credentials_dir, secret)) + + with open(config_file, 'r', encoding='utf-8') as f: + configs = yaml.safe_load(f.read()) + + config_type = 'tokenizers' if tokenizer_only else 'models' + path = configs[config_type][model] + registry = secrets['registry'] + + def get_oras_cmd(username: Optional[str] = None, + password: Optional[str] = None): + cmd = [ + ORAS_CLI, + 'pull', + f'{registry}/{path}', + '-o', + save_dir, + '--verbose', + '--concurrency', + str(concurrency), + ] + if username is not None: + cmd.extend(['--username', username]) + if password is not None: + cmd.extend(['--password', password]) + + return cmd + + cmd_without_creds = get_oras_cmd() + log.info(f'CMD for oras cli to run: {" ".join(cmd_without_creds)}') + cmd_to_run = get_oras_cmd(username=secrets['username'], + password=secrets['password']) + try: + subprocess.run(cmd_to_run, check=True) + except subprocess.CalledProcessError as e: + # Intercept the error and replace the cmd, which may have sensitive info. + raise subprocess.CalledProcessError(e.returncode, cmd_without_creds, + e.output, e.stderr) diff --git a/scripts/data_prep/convert_delta_to_json.py b/scripts/data_prep/convert_delta_to_json.py new file mode 100644 index 0000000000..9d73cb0e9f --- /dev/null +++ b/scripts/data_prep/convert_delta_to_json.py @@ -0,0 +1,556 @@ +# Copyright 2022 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +import logging +import os +import re +import time +import urllib.parse +from argparse import ArgumentParser, Namespace +from collections import namedtuple +from concurrent.futures import ProcessPoolExecutor +from typing import Iterable, List, Optional, Tuple, Union +from uuid import uuid4 + +import google.protobuf.any_pb2 as any_pb2 +import lz4.frame +import pandas as pd +import pyarrow as pa +import pyspark.sql.connect.proto as pb2 +import pyspark.sql.connect.proto.cloud_pb2 as cloud_pb2 +import requests +from databricks import sql +from databricks.connect import DatabricksSession +from databricks.sdk import WorkspaceClient +from databricks.sql.client import Connection as Connection +from databricks.sql.client import Cursor as Cursor +from packaging import version +from pyspark.sql import SparkSession +from pyspark.sql.connect.client.core import SparkConnectClient +from pyspark.sql.connect.client.reattach import \ + ExecutePlanResponseReattachableIterator +from pyspark.sql.connect.dataframe import DataFrame +from pyspark.sql.dataframe import DataFrame as SparkDataFrame +from pyspark.sql.types import Row + +MINIMUM_DB_CONNECT_DBR_VERSION = '14.1' +MINIMUM_SQ_CONNECT_DBR_VERSION = '12.2' + +log = logging.getLogger(__name__) + +Result = namedtuple( + 'Result', ['url', 'row_count', 'compressed_size', 'uncompressed_size' + ]) # pyright: ignore + +# ``collect_as_cf`` is an addon new feature monkey patch on top of the DB Connect package. +# It allows the client to fetch the results in different formats from the server. +# To be able to use the code make sure this module is not overriden by DB Connect classes. + + +def to_cf(self: SparkConnectClient, + plan: pb2.Plan, + type: str = 'json') -> Tuple[List[Result], int, bool]: + """Executes the query plans and return as presigned URLS for cloud fetch. + + It can handle the current output formats that are supported by the server. + In contrast to the regular API methods of the client, this method does not + return the schema and drops all other responses. + + Args: + plan (pb2.Plan): The plan object to be executed by spark. + type (str): The output format of the result, supported formats are 'json', 'csv', and 'arrow'. + + Returns: + Tuple[List[Result], int, bool]: A tuple containing: + - A list of Result namedtuples, each containing a URL, row count, compressed size, + and uncompressed size of the part of the result. + - Total row count of all parts of the result. + - A boolean indicating whether the result has been truncated. + """ + req = self._execute_plan_request_with_metadata() + req.plan.CopyFrom(plan) + + # Add the request options + if type == 'json': + format = cloud_pb2.ResultOptions.CloudOptions.FORMAT_JSON + elif type == 'csv': + format = cloud_pb2.ResultOptions.CloudOptions.FORMAT_CSV + elif type == 'arrow': + format = cloud_pb2.ResultOptions.CloudOptions.FORMAT_ARROW + else: + raise ValueError( + f'Only formats json, csv, and arrow are supported. Got invalid type {type}' + ) + + ro = cloud_pb2.ResultOptions( + type=cloud_pb2.ResultOptions.TYPE_CLOUD, + cloudOptions=cloud_pb2.ResultOptions.CloudOptions( + format=format, + useCompression=False, + )) + cloud_option = any_pb2.Any() + cloud_option.Pack(ro) + req.request_options.append( + pb2.ExecutePlanRequest.RequestOption(extension=cloud_option)) + + # Create the iterator + iterator = ExecutePlanResponseReattachableIterator(req, self._stub, + self._retry_policy, + self._builder.metadata()) + # Iterate over the response + result = [] + row_count = 0 + is_overflow = False + + for response in iterator: + if response.HasField('extension') and response.extension.Is( + cloud_pb2.CloudResultBatch.DESCRIPTOR): + batch = cloud_pb2.CloudResultBatch() + if not response.extension.Is(cloud_pb2.CloudResultBatch.DESCRIPTOR): + raise ValueError( + 'Response extension is not of type CloudResultBatch.') + response.extension.Unpack(batch) + result += [ + Result(b.url, b.row_count, b.compressed_size, + b.uncompressed_size) for b in batch.results + ] + row_count += sum(result.row_count for result in batch.results) + is_overflow |= batch.truncated + return result, row_count, is_overflow + + +SparkConnectClient.to_cf = to_cf # pyright: ignore + + +def collect_as_cf(self: DataFrame, + type: str = 'json') -> Tuple[List[Result], int, bool]: + """Collects DataFrame execution plan as presigned URLs. + + This method is a wrapper around the `to_cf` method of SparkConnectClient. It takes the + execution plan of the current DataFrame, converts it to a protocol buffer format, and then + uses the `to_cf` method to execute the plan and fetch results as presigned URLs. + + Args: + type (str): The output format of the result, supported formats are 'json', 'csv', and 'arrow'. + + Returns: + Tuple[List[Result], int, bool]: A tuple containing: + - A list of Result namedtuples, each containing a URL, row count, compressed size, + and uncompressed size of the part of the result. + - Total row count of all parts of the result. + - A boolean indicating whether the result is truncated or overflowed. + """ + query = self._plan.to_proto(self._session.client) # pyright: ignore + return self._session.client.to_cf(query, type) # pyright: ignore + + +DataFrame.collect_cf = collect_as_cf # pyright: ignore + + +def iterative_combine_jsons(json_directory: str, output_file: str) -> None: + """Combine jsonl files in json_directory into one big jsonl file. + + This function does not work for nested subdirectories. + + Args: + json_directory(str): directory containing the JSONL files + output_file(str): path to the output combined JSONL file + """ + json_files = [f for f in os.listdir(json_directory) if f.endswith('.jsonl')] + with open(output_file, 'w') as outfile: + for file_name in json_files: + with open(os.path.join(json_directory, file_name), 'r') as infile: + for line in infile: + outfile.write(line) + log.info('JSON files have been combined into a JSONL file.') + + +def run_query( + query: str, + method: str, + cursor: Optional[Cursor] = None, + spark: Optional[SparkSession] = None, + collect: bool = True +) -> Optional[Union[List[Row], DataFrame, SparkDataFrame]]: + """Run SQL query via databricks-connect or databricks-sql. + + Args: + query (str): sql query + method (str): select from dbsql and dbconnect + cursor (Optional[Cursor]): connection.cursor + spark (Optional[SparkSession]): spark session + collect (bool): whether to get the underlying data from spark dataframe + """ + if method == 'dbsql': + if cursor is None: + raise ValueError(f'cursor cannot be None if using method dbsql') + cursor.execute(query) + if collect: + return cursor.fetchall() + elif method == 'dbconnect': + if spark == None: + raise ValueError(f'sparkSession is required for dbconnect') + df = spark.sql(query) + if collect: + return df.collect() + return df + else: + raise ValueError(f'Unrecognized method: {method}') + + +def get_args(signed: List, json_output_folder: str, columns: List) -> Iterable: + for i, r in enumerate(signed): + yield (i, r.url, json_output_folder, columns) + + +def download(ipart: int, + url: str, + json_output_folder: str, + columns: Optional[List] = None, + resp_format: str = 'arrow', + compressed: bool = False) -> None: + """Thread download presigned url and save to jsonl locally. + + Args: + ipart (int): presigned url id + url (str): presigned url + json_output_folder (str): directory to save the ipart_th segment of dataframe + columns (list): schema to save to json + resp_format (str): whether to use arrow or json when collect + compressed (bool): if data is compressed before downloading. Need decompress if compressed=True. + """ + resp = requests.get(url) + if resp.status_code == 200: + if resp_format == 'json': + data = resp.json() + pd.DataFrame(data, columns=columns).to_json(os.path.join( + json_output_folder, 'part_' + str(ipart) + '.jsonl'), + orient='records', + lines=True) + return + + # When resp_format is arrow: + if compressed: + # The data is lz4 compressed arrow format. + # Decompress the data + decompressed_data = lz4.frame.decompress(resp.content) + # Convert the decompressed data into a PyArrow table + reader = pa.ipc.open_stream(decompressed_data) + else: + reader = pa.ipc.open_stream(resp.content) + table = reader.read_all() + + # Convert the PyArrow table into a pandas DataFrame + df = table.to_pandas() + df.to_json(os.path.join(json_output_folder, + 'part_' + str(ipart) + '.jsonl'), + orient='records', + lines=True, + force_ascii=False) + + +def download_starargs(args: Tuple) -> None: + return download(*args) + + +def fetch_data(method: str, cursor: Optional[Cursor], + sparkSession: Optional[SparkSession], start: int, end: int, + order_by: str, tablename: str, columns_str: str, + json_output_folder: str) -> None: + """Fetches a specified range of rows from a given table to a json file. + + This function executes a SQL query to retrieve a range of rows, determined by 'start' and 'end' indexes, + from a specified table and column set. The fetched data is then exported as a JSON file. + + Args: + method (str): The method to use for fetching data, either 'dbconnect' or 'dbsql'. + cursor (Optional[Cursor]): The cursor object for executing queries in 'dbsql' method. + sparkSession (Optional[SparkSession]): The Spark session object for executing queries in 'dbconnect' method. + start (int): The starting index for row fetching. + end (int): The ending index for row fetching. + order_by (str): The column name to use for ordering the rows. + tablename (str): The name of the table from which to fetch the data. + columns_str (str): The string representation of the columns to select from the table. + json_output_folder (str): The file path where the resulting JSON file will be saved. + + Returns: + None: The function doesn't return any value, but writes the result to a JSONL file. + """ + query = f""" + WITH NumberedRows AS ( + SELECT + *, + ROW_NUMBER() OVER (ORDER BY {order_by}) AS rn + FROM + {tablename} + ) + SELECT {columns_str} + FROM NumberedRows + WHERE rn BETWEEN {start+1} AND {end}""" + + if method == 'dbconnect': + spark_df = run_query(query, method, cursor, sparkSession, collect=False) + if spark_df is None: + raise RuntimeError( + f'Expect spark dataframe with {query} but got None') + pdf = spark_df.toPandas() # pyright: ignore + else: # method == 'dbsql': + ans = run_query(query, method, cursor, sparkSession, collect=True) + if ans is None: + raise RuntimeError(f'Got empty results with {query}') + records = [r.asDict() for r in ans] # pyright: ignore + pdf = pd.DataFrame.from_dict(records) + + pdf.to_json(os.path.join(json_output_folder, f'part_{start+1}_{end}.jsonl'), + orient='records', + lines=True) + + +def fetch( + method: str, + tablename: str, + json_output_folder: str, + batch_size: int = 1 << 30, + processes: int = 1, + sparkSession: Optional[SparkSession] = None, + dbsql: Optional[Connection] = None, +) -> None: + """Fetch UC delta table with databricks-connnect as JSONL. + + Args: + method (str): dbconnect or dbsql + tablename (str): catalog.scheme.tablename on UC + json_output_folder (str): path to write the result json file to + batch_size (int): number of rows that dbsql fetches each time to avoid OOM + processes (int): max number of processes to use to parallelize the fetch + sparkSession (pyspark.sql.sparksession): spark session + dbsql (databricks.sql.connect): dbsql session + """ + cursor = dbsql.cursor() if dbsql is not None else None + + try: + ans = run_query(f'SELECT COUNT(*) FROM {tablename}', method, cursor, + sparkSession) + nrows = [row.asDict() for row in ans][0].popitem()[1] # pyright: ignore + log.info(f'total_rows = {nrows}') + except Exception as e: + raise RuntimeError( + f'Error in get total rows from {tablename}. Restart sparkSession and try again' + ) from e + + try: + ans = run_query(f'SHOW COLUMNS IN {tablename}', method, cursor, + sparkSession) + columns = [row.asDict().popitem()[1] for row in ans] # pyright: ignore + order_by = columns[0] + columns_str = ','.join(columns) + log.info(f'order by column {order_by}') + except Exception as e: + raise RuntimeError( + f'Error in get columns from {tablename}. Restart sparkSession and try again' + ) from e + + if method == 'dbconnect' and sparkSession is not None: + log.info('processes = ', processes) + df = sparkSession.table(tablename) + + # Running the query and collecting the data as arrow or json. + signed, _, _ = df.collect_cf('arrow') # pyright: ignore + log.info(f'len(signed) = {len(signed)}') + + args = get_args(signed, json_output_folder, columns) + + # Stopping the SparkSession to avoid spilling connection state into the subprocesses. + sparkSession.stop() + + with ProcessPoolExecutor(max_workers=processes) as executor: + list(executor.map(download_starargs, args)) + + elif method == 'dbsql' and cursor is not None: + for start in range(0, nrows, batch_size): + log.warning(f'batch {start}') + end = min(start + batch_size, nrows) + fetch_data(method, cursor, sparkSession, start, end, order_by, + tablename, columns_str, json_output_folder) + + if cursor is not None: + cursor.close() + + +def validate_and_get_cluster_info(cluster_id: str, + databricks_host: str, + databricks_token: str, + http_path: Optional[str], + use_serverless: bool = False) -> tuple: + """Validate and get cluster info for running the Delta to JSONL conversion. + + Args: + cluster_id (str): cluster id to validate and fetch additional info for + databricks_host (str): databricks host name + databricks_token (str): databricks auth token + http_path (Optional[str]): http path to use for sql connect + use_serverless (bool): whether to use serverless or not + """ + method = 'dbsql' + dbsql = None + sparkSession = None + + if use_serverless: + method = 'dbconnect' + else: + w = WorkspaceClient() + res = w.clusters.get(cluster_id=cluster_id) + if res is None: + raise ValueError( + f'Cluster id {cluster_id} does not exist. Check cluster id and try again!' + ) + stripped_runtime = re.sub( + r'[a-zA-Z]', '', + res.spark_version.split('-scala')[0].replace('x-snapshot', '')) + runtime_version = re.sub(r'[.-]*$', '', stripped_runtime) + if version.parse(runtime_version) < version.parse( + MINIMUM_SQ_CONNECT_DBR_VERSION): + raise ValueError( + f'The minium DBR version required is {MINIMUM_SQ_CONNECT_DBR_VERSION} but got {version.parse(runtime_version)}' + ) + + if http_path is None and version.parse( + runtime_version) >= version.parse( + MINIMUM_DB_CONNECT_DBR_VERSION): + method = 'dbconnect' + + if method == 'dbconnect': + try: + if use_serverless: + session_id = str(uuid4()) + sparkSession = DatabricksSession.builder.host( + databricks_host).token(databricks_token).header( + 'x-databricks-session-id', session_id).getOrCreate() + + else: + sparkSession = DatabricksSession.builder.remote( + host=databricks_host, + token=databricks_token, + cluster_id=cluster_id).getOrCreate() + + except Exception as e: + raise RuntimeError( + 'Failed to create databricks connection. Check hostname and access token!' + ) from e + else: + try: + dbsql = sql.connect( + server_hostname=re.compile(r'^https?://').sub( + '', databricks_host).strip( + ), # sqlconnect hangs if hostname starts with https + http_path=http_path, + access_token=databricks_token, + ) + except Exception as e: + raise RuntimeError( + 'Failed to create sql connection to db workspace. To use sql connect, you need to provide http_path and cluster_id!' + ) from e + return method, dbsql, sparkSession + + +def fetch_DT(args: Namespace) -> None: + """Fetch UC Delta Table to local as jsonl.""" + log.info(f'Start .... Convert delta to json') + + obj = urllib.parse.urlparse(args.json_output_folder) + if obj.scheme != '': + raise ValueError( + f'Check the json_output_folder and verify it is a local path!') + + if os.path.exists(args.json_output_folder): + if not os.path.isdir(args.json_output_folder) or os.listdir( + args.json_output_folder): + raise RuntimeError( + f'A file or a folder {args.json_output_folder} already exists and is not empty. Remove it and retry!' + ) + + os.makedirs(args.json_output_folder, exist_ok=True) + + if not args.json_output_filename.endswith('.jsonl'): + raise ValueError('json_output_filename needs to be a jsonl file') + + log.info(f'Directory {args.json_output_folder} created.') + + method, dbsql, sparkSession = validate_and_get_cluster_info( + cluster_id=args.cluster_id, + databricks_host=args.DATABRICKS_HOST, + databricks_token=args.DATABRICKS_TOKEN, + http_path=args.http_path, + use_serverless=args.use_serverless) + + fetch(method, args.delta_table_name, args.json_output_folder, + args.batch_size, args.processes, sparkSession, dbsql) + + if dbsql is not None: + dbsql.close() + + # combine downloaded jsonl into one big jsonl for IFT + iterative_combine_jsons( + args.json_output_folder, + os.path.join(args.json_output_folder, args.json_output_filename)) + + +if __name__ == '__main__': + parser = ArgumentParser( + description= + 'Download delta table from UC and convert to json to save local') + parser.add_argument('--delta_table_name', + required=True, + type=str, + help='UC table ..') + parser.add_argument('--json_output_folder', + required=True, + type=str, + help='Local path to save the converted json') + parser.add_argument('--http_path', + required=False, + type=str, + help='http_path is set then dbsql method is used') + parser.add_argument('--batch_size', + required=False, + type=int, + default=1 << 30, + help='row chunks to transmit a time to avoid OOM') + parser.add_argument('--processes', + required=False, + type=int, + default=os.cpu_count(), + help='number of processes allowed to use') + parser.add_argument( + '--cluster_id', + required=False, + type=str, + help= + 'cluster id has runtime newer than 14.1.0 and access mode of either assigned or shared can use databricks-connect.' + ) + parser.add_argument( + '--use_serverless', + required=False, + type=bool, + default=False, + help= + 'Use serverless or not. Make sure the workspace is entitled with serverless' + ) + parser.add_argument( + '--json_output_filename', + required=False, + type=str, + default='train-00000-of-00001.jsonl', + help= + 'The name of the combined final jsonl that combines all partitioned jsonl' + ) + args = parser.parse_args() + + from databricks.sdk import WorkspaceClient + w = WorkspaceClient() + args.DATABRICKS_HOST = w.config.host + args.DATABRICKS_TOKEN = w.config.token + + tik = time.time() + fetch_DT(args) + log.info('Elapsed time', time.time() - tik) diff --git a/scripts/data_prep/convert_text_to_mds.py b/scripts/data_prep/convert_text_to_mds.py index dc7c514d75..bfd60b8ee1 100644 --- a/scripts/data_prep/convert_text_to_mds.py +++ b/scripts/data_prep/convert_text_to_mds.py @@ -47,18 +47,21 @@ def parse_args() -> Namespace: '--compression', type=str, default='zstd', + required=False, help='The compression algorithm to use for MDS writing', ) parser.add_argument( '--concat_tokens', type=int, + required=True, help='Convert text to tokens and concatenate up to this many tokens', ) parser.add_argument( '--tokenizer', type=str, + required=True, help='The name of the tokenizer to use', ) parser.add_argument( @@ -77,6 +80,13 @@ def parse_args() -> Namespace: help= 'The text to append to each example to separate concatenated examples', ) + parser.add_argument( + '--use_tokenizer_eos', + required=False, + action='store_true', + default=False, + help='Use the EOS text from the tokenizer.', + ) parser.add_argument( '--no_wrap', default=False, @@ -103,11 +113,15 @@ def parse_args() -> Namespace: parsed = parser.parse_args() - # Make sure we have needed concat options - if (parsed.concat_tokens is not None and - isinstance(parsed.concat_tokens, int) and parsed.tokenizer is None): - parser.error( - 'When setting --concat_tokens, you must specify a --tokenizer') + # Set eos token. + if parsed.use_tokenizer_eos: + # Ensure that eos text is not specified twice. + if parsed.eos_text is not None: + parser.error( + 'Cannot set --eos_text with --use_tokenizer_eos. Please specify one.' + ) + tokenizer = AutoTokenizer.from_pretrained(parsed.tokenizer) + parsed.eos_text = tokenizer.eos_token # now that we have validated them, change BOS/EOS to strings if parsed.bos_text is None: @@ -371,6 +385,10 @@ def convert_text_to_mds( local_output_folder = tempfile.TemporaryDirectory( ).name if is_remote_output else output_folder + if os.path.isdir(output_folder) and len(os.listdir(output_folder)) > 0: + raise FileExistsError( + f'{output_folder=} is not empty. Please remove or empty it.') + if processes > 1: # Download and convert the text files in parallel args = get_task_args(object_names, local_output_folder, input_folder, diff --git a/scripts/misc/download_hf_model.py b/scripts/misc/download_hf_model.py deleted file mode 100644 index 58c3445e7d..0000000000 --- a/scripts/misc/download_hf_model.py +++ /dev/null @@ -1,83 +0,0 @@ -# Copyright 2022 MosaicML LLM Foundry authors -# SPDX-License-Identifier: Apache-2.0 - -"""Script to download model weights from Hugging Face Hub or a cache server.""" -import argparse -import logging -import os -import sys - -from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE - -from llmfoundry.utils.model_download_utils import (download_from_cache_server, - download_from_hf_hub) - -HF_TOKEN_ENV_VAR = 'HUGGING_FACE_HUB_TOKEN' - -logging.basicConfig(format=f'%(asctime)s: %(levelname)s: %(name)s: %(message)s', - level=logging.INFO) -log = logging.getLogger(__name__) - -if __name__ == '__main__': - argparser = argparse.ArgumentParser() - argparser.add_argument('--model', type=str, required=True) - argparser.add_argument('--download-from', - type=str, - choices=['hf', 'cache'], - default='hf') - argparser.add_argument('--token', - type=str, - default=os.getenv(HF_TOKEN_ENV_VAR)) - argparser.add_argument('--save-dir', - type=str, - default=HUGGINGFACE_HUB_CACHE) - argparser.add_argument('--cache-url', type=str, default=None) - argparser.add_argument('--ignore-cert', action='store_true', default=False) - argparser.add_argument( - '--fallback', - action='store_true', - default=True, - help= - 'Whether to fallback to downloading from Hugging Face if download from cache fails', - ) - - args = argparser.parse_args(sys.argv[1:]) - if args.download_from == 'hf': - download_from_hf_hub(args.model, - save_dir=args.save_dir, - token=args.token) - else: - try: - download_from_cache_server( - args.model, - args.cache_url, - args.save_dir, - token=args.token, - ignore_cert=args.ignore_cert, - ) - - # A little hacky: run the Hugging Face download just to repair the symlinks in the HF cache file structure. - # This shouldn't actually download any files if the cache server download was successful, but should address - # a non-deterministic bug where the symlinks aren't repaired properly by the time the model is initialized. - log.info('Repairing Hugging Face cache symlinks') - - # Hide some noisy logs that aren't important for just the symlink repair. - old_level = logging.getLogger().level - logging.getLogger().setLevel(logging.ERROR) - download_from_hf_hub(args.model, - save_dir=args.save_dir, - token=args.token) - logging.getLogger().setLevel(old_level) - - except PermissionError: - log.error(f'Not authorized to download {args.model}.') - except Exception as e: - if args.fallback: - log.warning( - f'Failed to download {args.model} from cache server. Falling back to Hugging Face Hub. Error: {e}' - ) - download_from_hf_hub(args.model, - save_dir=args.save_dir, - token=args.token) - else: - raise e diff --git a/scripts/misc/download_model.py b/scripts/misc/download_model.py new file mode 100644 index 0000000000..13a63ce55e --- /dev/null +++ b/scripts/misc/download_model.py @@ -0,0 +1,127 @@ +# Copyright 2022 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +"""Script to download model weights from Hugging Face Hub or a cache server. + +Download from Hugging Face Hub: + python download_model.py hf --model mosaicml/mpt-7b --save-dir --token + +Download from ORAS registry: + python download_model.py oras --model mosaicml/mpt-7b --config-file \ + --credentials-dir --save-dir + +Download from an HTTP file server: + python download_model.py http --url https://server.com/models/mosaicml/mpt-7b/ --save-dir + +Download from an HTTP file server with fallback to Hugging Face Hub: + python download_model.py http --host https://server.com --path mosaicml/mpt-7b --save-dir \ + fallback-hf --model mosaicml/mpt-7b --token hf_token +""" +import argparse +import logging +import os + +from llmfoundry.utils.model_download_utils import ( + download_from_hf_hub, download_from_http_fileserver, download_from_oras) + +HF_TOKEN_ENV_VAR = 'HUGGING_FACE_HUB_TOKEN' + +logging.basicConfig(format=f'%(asctime)s: %(levelname)s: %(name)s: %(message)s', + level=logging.INFO) +log = logging.getLogger(__name__) + + +def add_hf_parser_arguments(parser: argparse.ArgumentParser) -> None: + parser.add_argument('--model', type=str, required=True) + parser.add_argument('--prefer-safetensors', type=bool, default=True) + parser.add_argument('--token', + type=str, + default=os.getenv(HF_TOKEN_ENV_VAR)) + + +def add_oras_parser_arguments(parser: argparse.ArgumentParser) -> None: + parser.add_argument('--model', type=str, required=True) + parser.add_argument('--config-file', type=str, required=True) + parser.add_argument('--credentials-dir', type=str, required=True) + parser.add_argument('--concurrency', type=int, default=10) + + +def add_http_parser_arguments(parser: argparse.ArgumentParser) -> None: + parser.add_argument('--url', type=str, required=True) + parser.add_argument('--ignore-cert', action='store_true', default=False) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser() + subparsers = parser.add_subparsers(dest='download_from', required=True) + + base_parser = argparse.ArgumentParser(add_help=False) + base_parser.add_argument('--save-dir', type=str, required=True) + base_parser.add_argument('--tokenizer-only', + default=False, + action='store_true') + + # Add subparser for downloading from Hugging Face Hub. + hf_parser = subparsers.add_parser('hf', parents=[base_parser]) + add_hf_parser_arguments(hf_parser) + + # Add subparser for downloading from ORAS registry. + oras_parser = subparsers.add_parser('oras', parents=[base_parser]) + add_oras_parser_arguments(oras_parser) + + # Add subparser for downloading from an HTTP file server. + http_parser = subparsers.add_parser('http', parents=[base_parser]) + add_http_parser_arguments(http_parser) + + # Add fallbacks for HTTP + fallback_subparsers = http_parser.add_subparsers(dest='fallback') + hf_fallback_parser = fallback_subparsers.add_parser('fallback-hf') + add_hf_parser_arguments(hf_fallback_parser) + + oras_fallback_parser = fallback_subparsers.add_parser('fallback-oras') + add_oras_parser_arguments(oras_fallback_parser) + + return parser.parse_args() + + +if __name__ == '__main__': + args = parse_args() + download_from = args.download_from + + if download_from == 'http': + if args.tokenizer_only: + raise ValueError( + 'tokenizer-only is not currently supported for http.') + try: + download_from_http_fileserver(args.url, args.save_dir, + args.ignore_cert) + except PermissionError as e: + log.error(f'Not authorized to download {args.model}.') + raise e + except Exception as e: + log.warning(f'Failed to download from HTTP server with error: {e}') + if args.fallback: + log.warning(f'Falling back to provided fallback destination.') + if args.fallback == 'fallback-hf': + download_from = 'hf' + elif args.fallback == 'fallback-oras': + download_from = 'oras' + else: + raise ValueError( + f'Invalid fallback destination {args.fallback}.') + else: + raise e + + if download_from == 'hf': + download_from_hf_hub(args.model, + save_dir=args.save_dir, + token=args.token, + tokenizer_only=args.tokenizer_only, + prefer_safetensors=args.prefer_safetensors) + elif download_from == 'oras': + download_from_oras(args.model, + args.config_file, + args.credentials_dir, + args.save_dir, + tokenizer_only=args.tokenizer_only, + concurrency=args.concurrency) diff --git a/scripts/train/benchmarking/README.md b/scripts/train/benchmarking/README.md index c3c8bc1c74..5414cdc7bf 100644 --- a/scripts/train/benchmarking/README.md +++ b/scripts/train/benchmarking/README.md @@ -69,6 +69,29 @@ Our microbatching engine enables microbatch sizes that do not divde Global Batch [comment]: # (TODO: Update tables with torch 2.0 after next Composer release) +## H100 80GB BF16 (Large Scale, >= 128 GPUs) +| Model | SeqLen (T) | # GPUs | GPU | MFU | HFU | Model TFLOP | MicroBatchSize | GradAccum | GlobalBatchSize | Throughput (S/s) | Throughput (T/s) | Throughput (T/s/GPU) | GlobalBatchSize (T) | Precision | MP Mode | Sharding Strategy | Activation Checkpointing | Activation CPUOffload | NumParams | +| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | +| 70b | 2048 | 512 | h100_80gb | 41.25 | 55.0 | 408 | 8 | 1 | 4096 | 251 | 515636 | 1007 | 8388608 | amp_bf16 | DEFAULT | FULL_SHARD | True | False | 64862437376 | +| 70b | 2048 | 256 | h100_80gb | 42.42 | 56.56 | 419 | 8 | 1 | 2048 | 129 | 265149 | 1035 | 4194304 | amp_bf16 | DEFAULT | FULL_SHARD | True | False | 64862437376 | +| 70b | 2048 | 128 | h100_80gb | 43.36 | 57.81 | 428 | 8 | 1 | 1024 | 66 | 135490 | 1058 | 2097152 | amp_bf16 | DEFAULT | FULL_SHARD | True | False | 64862437376 | +| 30b | 2048 | 512 | h100_80gb | 40.27 | 53.69 | 398 | 8 | 1 | 4096 | 528 | 1083366 | 2115 | 8388608 | amp_bf16 | DEFAULT | FULL_SHARD | True | False | 29975214080 | +| 30b | 2048 | 256 | h100_80gb | 40.89 | 54.52 | 404 | 8 | 1 | 2048 | 268 | 550022 | 2148 | 4194304 | amp_bf16 | DEFAULT | FULL_SHARD | True | False | 29975214080 | +| 30b | 2048 | 128 | h100_80gb | 41.85 | 55.8 | 414 | 8 | 1 | 1024 | 137 | 281491 | 2199 | 2097152 | amp_bf16 | DEFAULT | FULL_SHARD | True | False | 29975214080 | +| 13b | 2048 | 512 | h100_80gb | 41.12 | 54.83 | 406 | 16 | 1 | 8192 | 1238 | 2535811 | 4952 | 16777216 | amp_bf16 | DEFAULT | FULL_SHARD | True | False | 12853954560 | +| 13b | 2048 | 256 | h100_80gb | 41.42 | 55.23 | 409 | 16 | 1 | 4096 | 623 | 1277214 | 4989 | 8388608 | amp_bf16 | DEFAULT | FULL_SHARD | True | False | 12853954560 | +| 13b | 2048 | 128 | h100_80gb | 42.18 | 56.24 | 417 | 16 | 1 | 2048 | 317 | 650264 | 5080 | 4194304 | amp_bf16 | DEFAULT | FULL_SHARD | True | False | 12853954560 | +| 7b | 2048 | 512 | h100_80gb | 42.2 | 42.2 | 417 | 6 | 1 | 3072 | 2417 | 4951479 | 9670 | 6291456 | amp_bf16 | DEFAULT | FULL_SHARD | False | False | 6658859008 | +| 7b | 2048 | 256 | h100_80gb | 44.15 | 44.15 | 436 | 6 | 1 | 1536 | 1264 | 2590548 | 10119 | 3145728 | amp_bf16 | DEFAULT | FULL_SHARD | False | False | 6658859008 | +| 7b | 2048 | 128 | h100_80gb | 45.71 | 45.71 | 452 | 6 | 1 | 768 | 654 | 1340830 | 10475 | 1572864 | amp_bf16 | DEFAULT | FULL_SHARD | False | False | 6658859008 | +| 3b | 2048 | 512 | h100_80gb | 39.24 | 39.24 | 388 | 8 | 1 | 4096 | 5416 | 11092218 | 21664 | 8388608 | amp_bf16 | DEFAULT | SHARD_GRAD_OP | False | False | 2651837440 | +| 3b | 2048 | 256 | h100_80gb | 41.25 | 41.25 | 408 | 8 | 1 | 2048 | 2846 | 5829686 | 22772 | 4194304 | amp_bf16 | DEFAULT | SHARD_GRAD_OP | False | False | 2651837440 | +| 3b | 2048 | 128 | h100_80gb | 42.43 | 42.43 | 419 | 8 | 1 | 1024 | 1463 | 2998098 | 23422 | 2097152 | amp_bf16 | DEFAULT | SHARD_GRAD_OP | False | False | 2651837440 | +| 1b | 2048 | 512 | h100_80gb | 36.65 | 36.65 | 362 | 12 | 1 | 6144 | 9959 | 20396905 | 39837 | 12582912 | amp_bf16 | DEFAULT | SHARD_GRAD_OP | False | False | 1315950592 | +| 1b | 2048 | 256 | h100_80gb | 39.15 | 39.15 | 387 | 12 | 1 | 3072 | 5319 | 10894207 | 42555 | 6291456 | amp_bf16 | DEFAULT | SHARD_GRAD_OP | False | False | 1315950592 | +| 1b | 2048 | 128 | h100_80gb | 40.6 | 40.6 | 401 | 12 | 1 | 1536 | 2757 | 5647854 | 44123 | 3145728 | amp_bf16 | DEFAULT | SHARD_GRAD_OP | False | False | 1315950592 | + + ## H100 80GB BF16 | Model | SeqLen (T) | # GPUs | GPU | MFU | HFU | Model TFLOP | MicroBatchSize | GradAccum | GlobalBatchSize | Throughput (S/s) | Throughput (T/s) | Throughput (T/s/GPU) | GlobalBatchSize (T) | Precision | MP Mode | Sharding Strategy | Activation Checkpointing | Activation CPUOffload | NumParams | | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | diff --git a/scripts/train/benchmarking/submit_benchmarks.py b/scripts/train/benchmarking/submit_benchmarks.py index bfff10165a..a020745581 100644 --- a/scripts/train/benchmarking/submit_benchmarks.py +++ b/scripts/train/benchmarking/submit_benchmarks.py @@ -376,7 +376,7 @@ def get_integrations(project: str, git_integration.update({ 'integration_type': 'git_repo', 'git_repo': 'mosaicml/llm-foundry', - 'pip_install': '-e .[gpu]' + 'pip_install': '.[gpu-flash2]' }) integrations = [git_integration] @@ -398,8 +398,8 @@ def run_config(config: Tuple[str, int, int, str, str, int, str], { 'integration_type': 'git_repo', 'git_repo': 'mosaicml/llm-foundry', - 'git_branch': 'v0.4.0', - 'pip_install': '-e .[gpu]', + 'git_branch': 'main', + 'pip_install': '.[gpu-flash2]', }, { 'integration_type': 'wandb', @@ -411,7 +411,7 @@ def run_config(config: Tuple[str, int, int, str, str, int, str], command = '' if gpu_type == 'h100_80gb' and 'fp8' in precision: # Required for flash-attn and FP8 training command += f""" - pip install flash-attn==1.0.7 --no-build-isolation + pip install flash-attn==2.4.2 --no-build-isolation pip install git+https://github.com/NVIDIA/TransformerEngine.git@v0.10 pip uninstall install pydantic --yes pip install pydantic==1.9.0 @@ -420,11 +420,11 @@ def run_config(config: Tuple[str, int, int, str, str, int, str], if args.data_remote is None: command += f""" cd llm-foundry/scripts - python data_prep/convert_dataset_hf.py --dataset c4 --data_subset en --out_root ./my-copy-c4 --splits train_small val_small --concat_tokens {max_seq_len} --tokenizer gpt2 --eos_text '<|endoftext|>' + python data_prep/convert_dataset_hf.py --dataset c4 --data_subset en --out_root ./my-copy-c4 --splits train_small val_small --concat_tokens {max_seq_len} --eos_text '<|endoftext|>' composer train/train.py /mnt/config/parameters.yaml """ else: - command = f""" + command += f""" cd llm-foundry/scripts composer train/train.py /mnt/config/parameters.yaml """ @@ -487,6 +487,7 @@ def run_config(config: Tuple[str, int, int, str, str, int, str], print(f'Launching run {run.name}') else: print(f'run = {name}') + print(f'{config=}') def run_check_capacity(model_yaml: str, diff --git a/scripts/train/benchmarking/sweep.py b/scripts/train/benchmarking/sweep.py new file mode 100644 index 0000000000..441c9825f8 --- /dev/null +++ b/scripts/train/benchmarking/sweep.py @@ -0,0 +1,91 @@ +# Copyright 2022 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +import os + +# Define the arguments to sweep over + +base_args = [ + '--project tput', + '--image ', + '--git_branch main', + '--precisions bf16', + '--fsdp_config_mixed_precision PURE', + '--fsdp_config_limit_all_gathers true', + '--fsdp_config_forward_prefetch true', + '--fsdp_config_backward_prefetch BACKWARD_PRE', + '--activation_cpu_offload false', + '--seq_len_exp 11 11', + '--accum 1', + '--clusters ', + '--gpu_types h100_80gb', + '--data_remote ', + '--wandb true', + '--priority lowest', + '--RUN true', +] + +num_gpu_args_list = [ + [ + '--gpu_nums 128', + ], + [ + '--gpu_nums 256', + ], + [ + '--gpu_nums 512', + ], +] + +model_args_list = [ + [ + '--model_yamls 1b.yaml', + '--fsdp_config_activation_checkpointing false', + '--fsdp_config_shard_strategy SHARD_GRAD_OP', + '--microbatch_size 12', + '--attn_impl flash', + ], + [ + '--model_yamls 3b.yaml', + '--fsdp_config_activation_checkpointing false', + '--fsdp_config_shard_strategy SHARD_GRAD_OP', + '--microbatch_size 8', + '--attn_impl flash', + ], + [ + '--model_yamls 7b.yaml', + '--fsdp_config_activation_checkpointing false', + '--fsdp_config_shard_strategy FULL_SHARD', + '--microbatch_size 6', + '--attn_impl flash', + ], + [ + '--model_yamls 13b.yaml', + '--fsdp_config_activation_checkpointing true', + '--fsdp_config_shard_strategy FULL_SHARD', + '--microbatch_size 16', + '--attn_impl triton', + ], + [ + '--model_yamls 30b.yaml', + '--fsdp_config_activation_checkpointing true', + '--fsdp_config_shard_strategy FULL_SHARD', + '--microbatch_size 8', + '--attn_impl triton', + ], + [ + '--model_yamls 70b.yaml', + '--fsdp_config_activation_checkpointing true', + '--fsdp_config_shard_strategy FULL_SHARD', + '--microbatch_size 8', + '--attn_impl flash', + ], +] + +# Iterate over the arguments and call submit_benchmarks.py +for num_gpu_args in num_gpu_args_list: + for model_args in model_args_list: + command = ['python submit_benchmarks.py' + ] + base_args + num_gpu_args + model_args + command = ' '.join(command) + os.system(command) diff --git a/scripts/train/train.py b/scripts/train/train.py index d1e60ffb26..a0c18beb2e 100644 --- a/scripts/train/train.py +++ b/scripts/train/train.py @@ -278,7 +278,8 @@ def main(cfg: DictConfig) -> Trainer: logger_configs: Optional[DictConfig] = pop_config(cfg, 'loggers', must_exist=False, - default_value=None) + default_value=None, + convert=True) callback_configs: Optional[DictConfig] = pop_config(cfg, 'callbacks', must_exist=False, @@ -438,13 +439,17 @@ def main(cfg: DictConfig) -> Trainer: format= f'%(asctime)s: rank{dist.get_global_rank()}[%(process)d][%(threadName)s]: %(levelname)s: %(name)s: %(message)s' ) - logging.getLogger('llmfoundry').setLevel(python_log_level.upper()) + logging.getLogger('llmfoundry').setLevel( + python_log_level.upper()) # Foundry module + logging.getLogger(__name__).setLevel( + python_log_level.upper()) # Train script # Initialize context init_context = process_init_device(model_config, fsdp_config) logged_cfg.update({'fsdp_config': fsdp_config}, merge=True) # Build tokenizer + log.info('Building tokenizer...') tokenizer_name = tokenizer_config['name'] tokenizer_kwargs = tokenizer_config.get('kwargs', {}) tokenizer = build_tokenizer(tokenizer_name, tokenizer_kwargs) @@ -640,6 +645,11 @@ def main(cfg: DictConfig) -> Trainer: if __name__ == '__main__': yaml_path, args_list = sys.argv[1], sys.argv[2:] + + # Disable resolving environment variables through omegaconf. + om.clear_resolver('oc.env') + + # Load yaml and cli arguments. with open(yaml_path) as f: yaml_cfg = om.load(f) cli_cfg = om.from_cli(args_list) diff --git a/setup.py b/setup.py index 8122bbb14f..511e665ed4 100644 --- a/setup.py +++ b/setup.py @@ -1,3 +1,6 @@ +# Copyright 2024 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + """MosaicML LLM Foundry package setup.""" import os @@ -49,10 +52,10 @@ install_requires = [ 'mosaicml[libcloud,wandb,mlflow,oci,gcs]>=0.17.2,<0.18', 'accelerate>=0.25,<0.26', # for HF inference `device_map` - 'transformers>=4.36,<4.37', + 'transformers>=4.37,<4.38', 'mosaicml-streaming>=0.7.2,<0.8', 'torch>=2.1,<2.1.1', - 'datasets==2.15.0', + 'datasets>=2.16,<2.17', 'fsspec==2023.6.0', # newer version results in a bug in datasets that duplicates data 'sentencepiece==0.1.97', 'einops==0.7.0', @@ -73,7 +76,7 @@ extra_deps = {} extra_deps['dev'] = [ - 'pre-commit>=2.18.1,<3', + 'pre-commit>=3.4.0,<4', 'pytest>=7.2.1,<8', 'pytest_codeblocks>=0.16.1,<0.17', 'pytest-cov>=4,<5', @@ -84,7 +87,10 @@ ] extra_deps['databricks'] = [ - 'mosaicml[databricks]>=0.17.2,<0.18', + 'mosaicml[databricks]>=0.17.1,<0.18', + 'databricks-sql-connector>=3,<4', + 'databricks-connect==14.1.0', + 'lz4>=4,<5', ] extra_deps['tensorboard'] = [ @@ -93,13 +99,13 @@ extra_deps['gpu'] = [ 'flash-attn==1.0.9', - 'mosaicml-turbo==0.0.7', + 'mosaicml-turbo==0.0.8', # PyPI does not support direct dependencies, so we remove this line before uploading from PyPI 'xentropy-cuda-lib@git+https://github.com/HazyResearch/flash-attention.git@v1.0.9#subdirectory=csrc/xentropy', ] extra_deps['gpu-flash2'] = [ - 'flash-attn==2.3.6', - 'mosaicml-turbo==0.0.7', + 'flash-attn==2.4.2', + 'mosaicml-turbo==0.0.8', ] extra_deps['peft'] = [ diff --git a/tests/a_scripts/data_prep/test_convert_delta_to_json.py b/tests/a_scripts/data_prep/test_convert_delta_to_json.py new file mode 100644 index 0000000000..b366d8635a --- /dev/null +++ b/tests/a_scripts/data_prep/test_convert_delta_to_json.py @@ -0,0 +1,305 @@ +# Copyright 2022 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +# copyright 2022 mosaicml llm foundry authors +# spdx-license-identifier: apache-2.0 + +import unittest +from argparse import Namespace +from typing import Any +from unittest.mock import MagicMock, mock_open, patch + +from scripts.data_prep.convert_delta_to_json import (download, fetch_DT, + iterative_combine_jsons, + run_query) + + +class TestConverDeltaToJsonl(unittest.TestCase): + + @patch('scripts.data_prep.convert_delta_to_json.sql.connect') + @patch('scripts.data_prep.convert_delta_to_json.os.makedirs') + @patch('scripts.data_prep.convert_delta_to_json.iterative_combine_jsons') + @patch('scripts.data_prep.convert_delta_to_json.fetch') + @patch('scripts.data_prep.convert_delta_to_json.WorkspaceClient') + def test_stream_delta_to_json(self, mock_workspace_client: Any, + mock_fetch: Any, mock_combine_jsons: Any, + mock_makedirs: Any, mock_sql_connect: Any): + + args = MagicMock() + args.delta_table_name = 'test_table' + args.json_output_folder = '/path/to/jsonl' + args.DATABRICKS_HOST = 'test_host' + args.DATABRICKS_TOKEN = 'test_token' + args.http_path = 'test_path' + args.batch_size = 1000 + args.partitions = 1 + args.cluster_id = '1234' + args.debug = False + args.use_serverless = False + args.json_output_filename = 'combined.jsonl' + + mock_cluster_get = MagicMock() + mock_cluster_get.return_value = MagicMock( + spark_version='14.1.0-scala2.12') + mock_workspace_client.return_value.clusters.get = mock_cluster_get + + fetch_DT(args) + mock_sql_connect.assert_called_once_with(server_hostname='test_host', + http_path='test_path', + access_token='test_token') + mock_makedirs.assert_called_once_with('/path/to/jsonl', exist_ok=True) + mock_fetch.assert_called_once() + mock_combine_jsons.assert_called_once_with( + '/path/to/jsonl', '/path/to/jsonl/combined.jsonl') + + @patch('scripts.data_prep.convert_delta_to_json.os.listdir') + @patch('builtins.open', + new_callable=mock_open, + read_data='{"key": "value"}') + def test_iterative_combine_jsons(self, mock_file: Any, mock_listdir: Any): + mock_listdir.return_value = ['file1.jsonl', 'file2.jsonl'] + json_directory = '/fake/dir' + output_file = '/fake/output.jsonl' + + iterative_combine_jsons(json_directory, output_file) + + mock_listdir.assert_called_once_with(json_directory) + mock_file.assert_called() + """ + Diagnostic print + for call_args in mock_file().write.call_args_list: + print(call_args) + -------------------- + call('{') + call('"key"') + call(': ') + call('"value"') + call('}') + call('\n') + call('{') + call('"key"') + call(': ') + call('"value"') + call('}') + call('\n') + -------------------- + """ + self.assertEqual(mock_file().write.call_count, 2) + + @patch('scripts.data_prep.convert_delta_to_json.SparkSession') + def test_run_query_dbconnect(self, mock_spark: Any): + method = 'dbconnect' + mock_cursor = None + mock_spark.sql.return_value.collect.return_value = 'result' + + result = run_query('SELECT * FROM table', + method, + cursor=mock_cursor, + spark=mock_spark) + + mock_spark.sql.assert_called_once_with('SELECT * FROM table') + self.assertEqual(result, 'result') + + @patch('scripts.data_prep.convert_delta_to_json.Cursor') + def test_run_query_dbsql(self, mock_cursor: Any): + method = 'dbsql' + mock_cursor.fetchall.return_value = 'result' + mock_spark = None + + result = run_query('SELECT * FROM table', + method, + cursor=mock_cursor, + spark=mock_spark) + + mock_cursor.execute.assert_called_once_with('SELECT * FROM table') + self.assertEqual(result, 'result') + + @patch('scripts.data_prep.convert_delta_to_json.requests.get') + @patch('scripts.data_prep.convert_delta_to_json.pd.DataFrame.to_json') + @patch('scripts.data_prep.convert_delta_to_json.os.path.join', + return_value='/fake/path/part_1.jsonl') + @patch('scripts.data_prep.convert_delta_to_json.time.sleep' + ) # Mock sleep to speed up the test + def test_download_success(self, mock_sleep: Any, mock_join: Any, + mock_to_json: Any, mock_get: Any): + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = [['val1.1', 'val1.2'], + ['val2.1', 'val2.2']] + mock_get.return_value = mock_response + + download(1, + 'http://fakeurl.com/data', + '/fake/path', ['A', 'B'], + resp_format='json') + + mock_get.assert_called_with('http://fakeurl.com/data') + mock_join.assert_called_with('/fake/path', 'part_1.jsonl') + mock_to_json.assert_called_with('/fake/path/part_1.jsonl', + orient='records', + lines=True) + + mock_get.assert_called_once_with('http://fakeurl.com/data') + + @patch('scripts.data_prep.convert_delta_to_json.sql.connect') + @patch('scripts.data_prep.convert_delta_to_json.DatabricksSession') + @patch('scripts.data_prep.convert_delta_to_json.WorkspaceClient') + @patch('scripts.data_prep.convert_delta_to_json.os.makedirs') + @patch('scripts.data_prep.convert_delta_to_json.iterative_combine_jsons') + @patch('scripts.data_prep.convert_delta_to_json.fetch') + def test_dbconnect_called(self, mock_fetch: Any, mock_combine_jsons: Any, + mock_makedirs: Any, mock_workspace_client: Any, + mock_databricks_session: Any, + mock_sql_connect: Any): + + args = MagicMock() + + args.delta_table_name = 'test_table' + args.json_output_folder = '/path/to/jsonl' + # Execute function with http_path=None (should use dbconnect) + args.http_path = None + args.cluster_id = '1234' + args.DATABRICKS_HOST = 'host' + args.DATABRICKS_TOKEN = 'token' + args.use_serverless = False + + mock_cluster_response = Namespace(spark_version='14.1.0-scala2.12') + mock_workspace_client.return_value.clusters.get.return_value = mock_cluster_response + + mock_remote = MagicMock() + mock_remote.getOrCreate.return_value = MagicMock( + ) # Mock return value for getOrCreate + mock_databricks_session.builder.remote.return_value = mock_remote + + fetch_DT(args) + mock_databricks_session.builder.remote.assert_called_once_with( + host=args.DATABRICKS_HOST, + token=args.DATABRICKS_TOKEN, + cluster_id=args.cluster_id) + + @patch('scripts.data_prep.convert_delta_to_json.sql.connect') + @patch('scripts.data_prep.convert_delta_to_json.DatabricksSession') + @patch('scripts.data_prep.convert_delta_to_json.WorkspaceClient') + @patch('scripts.data_prep.convert_delta_to_json.os.makedirs') + @patch('scripts.data_prep.convert_delta_to_json.iterative_combine_jsons') + @patch('scripts.data_prep.convert_delta_to_json.fetch') + def test_sqlconnect_called_dbr13(self, mock_fetch: Any, + mock_combine_jsons: Any, + mock_makedirs: Any, + mock_workspace_client: Any, + mock_databricks_session: Any, + mock_sql_connect: Any): + + args = MagicMock() + + args.delta_table_name = 'test_table' + args.json_output_folder = '/path/to/jsonl' + # Execute function with http_path=None (should use dbconnect) + args.http_path = 'test_path' + args.cluster_id = '1234' + args.DATABRICKS_HOST = 'host' + args.DATABRICKS_TOKEN = 'token' + args.use_serverless = False + + mock_cluster_response = Namespace(spark_version='13.0.0-scala2.12') + mock_workspace_client.return_value.clusters.get.return_value = mock_cluster_response + + fetch_DT(args) + mock_sql_connect.assert_called_once_with( + server_hostname=args.DATABRICKS_HOST, + http_path=args.http_path, + access_token=args.DATABRICKS_TOKEN) + + @patch('scripts.data_prep.convert_delta_to_json.sql.connect') + @patch('scripts.data_prep.convert_delta_to_json.DatabricksSession') + @patch('scripts.data_prep.convert_delta_to_json.WorkspaceClient') + @patch('scripts.data_prep.convert_delta_to_json.os.makedirs') + @patch('scripts.data_prep.convert_delta_to_json.iterative_combine_jsons') + @patch('scripts.data_prep.convert_delta_to_json.fetch') + def test_sqlconnect_called_dbr14(self, mock_fetch: Any, + mock_combine_jsons: Any, + mock_makedirs: Any, + mock_workspace_client: Any, + mock_databricks_session: Any, + mock_sql_connect: Any): + + args = MagicMock() + + args.delta_table_name = 'test_table' + args.json_output_folder = '/path/to/jsonl' + # Execute function with http_path=None (should use dbconnect) + args.http_path = 'test_path' + args.cluster_id = '1234' + args.DATABRICKS_HOST = 'host' + args.DATABRICKS_TOKEN = 'token' + args.use_serverless = False + + mock_cluster_response = Namespace(spark_version='14.2.0-scala2.12') + mock_workspace_client.return_value.clusters.get.return_value = mock_cluster_response + + fetch_DT(args) + mock_sql_connect.assert_called_once_with( + server_hostname=args.DATABRICKS_HOST, + http_path=args.http_path, + access_token=args.DATABRICKS_TOKEN) + + @patch('scripts.data_prep.convert_delta_to_json.sql.connect') + @patch('scripts.data_prep.convert_delta_to_json.DatabricksSession') + @patch('scripts.data_prep.convert_delta_to_json.WorkspaceClient') + @patch('scripts.data_prep.convert_delta_to_json.os.makedirs') + @patch('scripts.data_prep.convert_delta_to_json.iterative_combine_jsons') + @patch('scripts.data_prep.convert_delta_to_json.fetch') + def test_sqlconnect_called_https(self, mock_fetch: Any, + mock_combine_jsons: Any, + mock_makedirs: Any, + mock_workspace_client: Any, + mock_databricks_session: Any, + mock_sql_connect: Any): + + args = MagicMock() + + args.delta_table_name = 'test_table' + args.json_output_folder = '/path/to/jsonl' + # Execute function with http_path=None (should use dbconnect) + args.http_path = 'test_path' + args.cluster_id = '1234' + args.DATABRICKS_HOST = 'https://test-host' + args.DATABRICKS_TOKEN = 'token' + args.use_serverless = False + + mock_cluster_response = Namespace(spark_version='14.2.0-scala2.12') + mock_workspace_client.return_value.clusters.get.return_value = mock_cluster_response + + fetch_DT(args) + mock_sql_connect.assert_called_once_with( + server_hostname='test-host', + http_path=args.http_path, + access_token=args.DATABRICKS_TOKEN) + + @patch('scripts.data_prep.convert_delta_to_json.sql.connect') + @patch('scripts.data_prep.convert_delta_to_json.DatabricksSession') + @patch('scripts.data_prep.convert_delta_to_json.WorkspaceClient') + @patch('scripts.data_prep.convert_delta_to_json.os.makedirs') + @patch('scripts.data_prep.convert_delta_to_json.iterative_combine_jsons') + @patch('scripts.data_prep.convert_delta_to_json.fetch') + def test_serverless(self, mock_fetch: Any, mock_combine_jsons: Any, + mock_makedirs: Any, mock_workspace_client: Any, + mock_databricks_session: Any, mock_sql_connect: Any): + + args = MagicMock() + + args.delta_table_name = 'test_table' + args.json_output_folder = '/path/to/jsonl' + # Execute function with http_path=None (should use dbconnect) + args.http_path = 'test_path' + args.cluster_id = '1234' + args.DATABRICKS_HOST = 'https://test-host' + args.DATABRICKS_TOKEN = 'token' + args.use_serverless = True + + mock_cluster_response = Namespace(spark_version='14.2.0-scala2.12') + mock_workspace_client.return_value.clusters.get.return_value = mock_cluster_response + + fetch_DT(args) + assert not mock_sql_connect.called + assert not mock_databricks_session.builder.remote.called diff --git a/tests/a_scripts/data_prep/test_convert_text_to_mds.py b/tests/a_scripts/data_prep/test_convert_text_to_mds.py index cc293a2cdd..3a00a8889f 100644 --- a/tests/a_scripts/data_prep/test_convert_text_to_mds.py +++ b/tests/a_scripts/data_prep/test_convert_text_to_mds.py @@ -3,6 +3,7 @@ import os import pathlib +import shutil from concurrent.futures import ProcessPoolExecutor from glob import glob from typing import Callable, Iterable, List @@ -55,23 +56,6 @@ def upload_object(self, object_name: str, filename: str): remote_file.write(local_file.read()) -def _call_convert_text_to_mds(processes: int, tokenizer_name: str, - concat_tokens: int) -> None: - convert_text_to_mds( - tokenizer_name=tokenizer_name, - output_folder=f's3://fake-test-output-path', - input_folder=f's3://fake-test-input-path', - concat_tokens=concat_tokens, - eos_text='', - bos_text='', - no_wrap=False, - compression='zstd', - processes=processes, - args_str='Namespace()', - reprocess=False, - ) - - # Mock starmap with no multiprocessing def _mock_map(func: Callable, args: Iterable) -> Iterable: for arg in args: @@ -107,9 +91,22 @@ def test_single_and_multi_process(merge_shard_groups: Mock, maybe_create_object_store_from_uri.return_value = mock_object_store parse_uri.return_value = ('s3', 'fake-test-bucket', str(remote_folder)) - _call_convert_text_to_mds(processes=processes, - tokenizer_name=tokenizer_name, - concat_tokens=concat_tokens) + def call_convert_text_to_mds() -> None: + convert_text_to_mds( + tokenizer_name=tokenizer_name, + output_folder=f's3://fake-test-output-path', + input_folder=f's3://fake-test-input-path', + concat_tokens=concat_tokens, + eos_text='', + bos_text='', + no_wrap=False, + compression='zstd', + processes=processes, + args_str='Namespace()', + reprocess=False, + ) + + call_convert_text_to_mds() # Check call counts assert download_and_convert.call_count == processes # called once per process @@ -131,9 +128,7 @@ def test_single_and_multi_process(merge_shard_groups: Mock, _assert_files_exist(prefix=remote_folder, files=['index.json', DONE_FILENAME] + shards) - _call_convert_text_to_mds(processes=processes, - tokenizer_name=tokenizer_name, - concat_tokens=concat_tokens) + call_convert_text_to_mds() # Check call counts assert download_and_convert.call_count == processes # No changes because we shoudn't reprocess @@ -146,9 +141,7 @@ def test_single_and_multi_process(merge_shard_groups: Mock, mock_object_store = Mock(wraps=object_store) maybe_create_object_store_from_uri.return_value = mock_object_store - _call_convert_text_to_mds(processes=processes, - tokenizer_name=tokenizer_name, - concat_tokens=concat_tokens) + call_convert_text_to_mds() # Check call counts assert download_and_convert.call_count == processes * 2 # called once per process @@ -187,31 +180,42 @@ def test_local_path(tmp_path: pathlib.Path): input_folder = tmp_path / 'input' output_folder = tmp_path / 'output' + def call_convert_text_to_mds(reprocess: bool): + convert_text_to_mds( + tokenizer_name='mosaicml/mpt-7b', + output_folder=str(output_folder), + input_folder=str(input_folder), + concat_tokens=1, + eos_text='', + bos_text='', + no_wrap=False, + compression='zstd', + processes=1, + args_str='Namespace()', + reprocess=reprocess, + ) + # Create input text data os.makedirs(input_folder, exist_ok=True) with open(input_folder / 'test.txt', 'w') as f: f.write('test') # Convert text data to mds - convert_text_to_mds( - tokenizer_name='mosaicml/mpt-7b', - output_folder=str(output_folder), - input_folder=str(input_folder), - concat_tokens=1, - eos_text='', - bos_text='', - no_wrap=False, - compression='zstd', - processes=1, - args_str='Namespace()', - reprocess=False, - ) + call_convert_text_to_mds(reprocess=False) # Make sure all the files exist as expected. assert os.path.exists(output_folder / '.text_to_mds_conversion_done') assert os.path.exists(output_folder / 'index.json') assert os.path.exists(output_folder / 'shard.00000.mds.zstd') + # Test reprocessing. + with pytest.raises(FileExistsError): + call_convert_text_to_mds(reprocess=True) + + shutil.rmtree(output_folder) + + call_convert_text_to_mds(reprocess=True) + def test_is_already_processed(tmp_path: pathlib.Path): tmp_path_str = str(tmp_path) diff --git a/tests/a_scripts/eval/test_eval.py b/tests/a_scripts/eval/test_eval.py index e8d86903dc..c9dfb88732 100644 --- a/tests/a_scripts/eval/test_eval.py +++ b/tests/a_scripts/eval/test_eval.py @@ -71,7 +71,6 @@ def test_icl_eval(eval_cfg: Union[om.ListConfig, om.DictConfig], capfd: Any, assert expected_results in out -@pytest.mark.gpu def test_loader_eval(capfd: Any, mock_saved_model_path: Any, tmp_path: pathlib.Path): diff --git a/tests/a_scripts/inference/test_convert_composer_to_hf.py b/tests/a_scripts/inference/test_convert_composer_to_hf.py index 28fb9219f8..deed181475 100644 --- a/tests/a_scripts/inference/test_convert_composer_to_hf.py +++ b/tests/a_scripts/inference/test_convert_composer_to_hf.py @@ -300,6 +300,8 @@ def test_huggingface_conversion_callback_interval( mlflow_logger_mock.save_model = MagicMock() mlflow_logger_mock.register_model = MagicMock() mlflow_logger_mock.model_registry_prefix = '' + mlflow_logger_mock._experiment_id = 'mlflow-experiment-id' + mlflow_logger_mock._run_id = 'mlflow-run-id' trainer = Trainer( model=original_model, device='gpu', @@ -534,6 +536,8 @@ def test_huggingface_conversion_callback( mlflow_logger_mock.save_model = MagicMock() mlflow_logger_mock.register_model = MagicMock() mlflow_logger_mock.model_registry_prefix = '' + mlflow_logger_mock._experiment_id = 'mlflow-experiment-id' + mlflow_logger_mock._run_id = 'mlflow-run-id' trainer = Trainer( model=original_model, device='gpu', diff --git a/tests/data/test_dataloader.py b/tests/data/test_dataloader.py index bf818347a0..44d0442a87 100644 --- a/tests/data/test_dataloader.py +++ b/tests/data/test_dataloader.py @@ -7,7 +7,9 @@ import shutil import tempfile from argparse import Namespace -from typing import Literal, Optional, Union +from contextlib import nullcontext as does_not_raise +from pathlib import Path +from typing import ContextManager, Literal, Optional, Union from unittest.mock import MagicMock import pytest @@ -23,6 +25,8 @@ from llmfoundry.data import build_dataloader from llmfoundry.data.finetuning.tasks import (_ALLOWED_PROMPT_KEYS, _ALLOWED_RESPONSE_KEYS, + DOWNLOADED_FT_DATASETS_DIRPATH, + SUPPORTED_EXTENSIONS, _tokenize_formatted_example) from llmfoundry.data.text_data import (ConcatenatedSequenceCollatorWrapper, build_text_dataloader, @@ -306,6 +310,51 @@ def test_finetuning_dataloader(decoder_only_format: bool, break +@pytest.mark.parametrize( + 'hf_name, hf_revision, expectation', + [('HuggingFaceH4/databricks_dolly_15k', None, does_not_raise()), + ('squad', '5fe18c', pytest.raises(FileNotFoundError))]) +def test_finetuning_dataloader_safe_load(hf_name: str, + hf_revision: Optional[str], + expectation: ContextManager): + cfg = DictConfig({ + 'name': 'finetuning', + 'dataset': { + 'hf_name': hf_name, + 'split': 'train', + 'max_seq_len': 8, + 'decoder_only_format': True, + 'shuffle': True, + 'safe_load': True, + 'hf_kwargs': { + 'revision': hf_revision + } + }, + 'drop_last': False, + 'num_workers': 0, + 'pin_memory': False, + 'prefetch_factor': None, + 'persistent_workers': False, + 'timeout': 0 + }) + + tokenizer = build_tokenizer('gpt2', {}) + + with expectation: + _ = build_finetuning_dataloader(cfg, tokenizer, 1) + + # If no raised errors, we should expect downloaded files with only safe file types. + if expectation == does_not_raise(): + download_dir = os.path.join(DOWNLOADED_FT_DATASETS_DIRPATH, hf_name) + downloaded_files = [ + file for _, _, files in os.walk(download_dir) for file in files + ] + assert len(downloaded_files) > 0 + assert all( + Path(file).suffix in SUPPORTED_EXTENSIONS + for file in downloaded_files) + + @pytest.mark.world_size(2) @pytest.mark.gpu @pytest.mark.parametrize('dataset_size', [4, 8]) @@ -441,12 +490,16 @@ def test_finetuning_dataloader_custom_split(tmp_path: pathlib.Path, split: str): def mock_get_file(path: str, destination: str, overwrite: bool = False): - make_tiny_ft_dataset(path=destination, size=16) + if Path(destination).suffix == '.jsonl': + make_tiny_ft_dataset(path=destination, size=16) + else: + raise FileNotFoundError( + f'Test error in mock_get_file. {path} does not exist.') @pytest.mark.parametrize('split', ['train', 'custom', 'custom-dash', 'data']) def test_finetuning_dataloader_custom_split_remote( - tmp_path: pathlib.Path, split: str, monkeypatch: pytest.MonkeyPatch): + split: str, monkeypatch: pytest.MonkeyPatch): tokenizer_name = 'gpt2' max_seq_len = 2048 diff --git a/tests/data_utils.py b/tests/data_utils.py index a0ad6bcd13..0afe3fa1a1 100644 --- a/tests/data_utils.py +++ b/tests/data_utils.py @@ -3,9 +3,9 @@ import json import os -import pathlib import shutil from argparse import Namespace +from pathlib import Path from typing import Optional from omegaconf import DictConfig @@ -26,6 +26,8 @@ def make_tiny_ft_dataset( start_token: Optional[str] = None, end_token: Optional[str] = None, ): + if Path(path).suffix != '.jsonl': + raise ValueError(f'Path {path} must be a jsonl file.') good_sample = {'prompt': 'hello', 'response': 'goodbye'} samples = [good_sample] * size if add_bad_data_dropped: @@ -77,7 +79,7 @@ def make_tiny_ft_dataset( _f.write('\n') -def create_c4_dataset_xxsmall(path: pathlib.Path) -> str: +def create_c4_dataset_xxsmall(path: Path) -> str: """Creates a small mocked version of the C4 dataset.""" c4_dir = os.path.join(path, f'my-copy-c4') downloaded_split = 'val_xxsmall' # very fast to convert @@ -109,7 +111,7 @@ def create_c4_dataset_xxsmall(path: pathlib.Path) -> str: return c4_dir -def create_arxiv_dataset(path: pathlib.Path) -> str: +def create_arxiv_dataset(path: Path) -> str: """Creates an arxiv dataset.""" arxiv_dir = os.path.join(path, f'my-copy-arxiv') downloaded_split = 'train' diff --git a/tests/fixtures/autouse.py b/tests/fixtures/autouse.py index 75caa6c941..ccbe1b69f7 100644 --- a/tests/fixtures/autouse.py +++ b/tests/fixtures/autouse.py @@ -17,12 +17,8 @@ @pytest.fixture(autouse=True) def initialize_dist(request: pytest.FixtureRequest): """Initialize the default PyTorch distributed process group for tests.""" - # should we just always initialize dist like in train.py? - _default = pytest.mark.world_size(1).mark - world_size = request.node.get_closest_marker('world_size', _default).args[0] gpu = request.node.get_closest_marker('gpu') - if world_size > 1: - dist.initialize_dist(get_device('gpu' if gpu is not None else 'cpu')) + dist.initialize_dist(get_device('gpu' if gpu is not None else 'cpu')) @pytest.fixture(autouse=True) diff --git a/tests/models/layers/test_flash_attn.py b/tests/models/layers/test_flash_attn.py index acefd2c42d..9471cdac6a 100644 --- a/tests/models/layers/test_flash_attn.py +++ b/tests/models/layers/test_flash_attn.py @@ -6,9 +6,13 @@ import pytest import torch -from llmfoundry.models.layers.attention import (flash_attn_fn, +from llmfoundry.models.layers.attention import (attn_bias_shape, + build_attn_bias, + check_alibi_support, + flash_attn_fn, gen_slopes, is_flash_v2_installed, triton_flash_attn_fn) +from llmfoundry.models.mpt.modeling_mpt import gen_flash_attn_padding_info @pytest.mark.gpu @@ -32,22 +36,24 @@ def test_gqa_kv_repetition(kv_n_heads: int): kv_n_heads * d).to(torch.bfloat16).cuda() value_1.requires_grad = True - output_1, _, _ = flash_attn_fn(query=query_1, - key=key_1, - value=value_1, - n_heads=n_heads, - kv_n_heads=kv_n_heads, - past_key_value=None, - softmax_scale=1 / math.sqrt(d), - attn_bias=None, - key_padding_mask=None, - is_causal=True, - dropout_p=0.0, - training=False, - needs_weights=False, - multiquery=False, - attention_mask_in_length=None, - should_repeat_kv_for_gqa=True) + output_1, _, _ = flash_attn_fn( + query=query_1, + key=key_1, + value=value_1, + n_heads=n_heads, + kv_n_heads=kv_n_heads, + past_key_value=None, + softmax_scale=1 / math.sqrt(d), + attn_bias=None, + key_padding_mask=None, + is_causal=True, + dropout_p=0.0, + training=False, + needs_weights=False, + multiquery=False, + flash_attn_padding_info=gen_flash_attn_padding_info( + bsz, seqlen_1, 0, query_1.device, None, None), + should_repeat_kv_for_gqa=True) output_1.sum().backward() @@ -58,22 +64,24 @@ def test_gqa_kv_repetition(kv_n_heads: int): value_2 = value_1.detach().clone() value_2.requires_grad = True - output_2, _, _ = flash_attn_fn(query=query_2, - key=key_2, - value=value_2, - n_heads=n_heads, - kv_n_heads=kv_n_heads, - past_key_value=None, - softmax_scale=1 / math.sqrt(d), - attn_bias=None, - key_padding_mask=None, - is_causal=True, - dropout_p=0.0, - training=False, - needs_weights=False, - multiquery=False, - attention_mask_in_length=None, - should_repeat_kv_for_gqa=False) + output_2, _, _ = flash_attn_fn( + query=query_2, + key=key_2, + value=value_2, + n_heads=n_heads, + kv_n_heads=kv_n_heads, + past_key_value=None, + softmax_scale=1 / math.sqrt(d), + attn_bias=None, + key_padding_mask=None, + is_causal=True, + dropout_p=0.0, + training=False, + needs_weights=False, + multiquery=False, + flash_attn_padding_info=gen_flash_attn_padding_info( + bsz, seqlen_1, 0, query_2.device, None, None), + should_repeat_kv_for_gqa=False) output_2.sum().backward() assert torch.allclose(output_1, output_2) @@ -111,6 +119,9 @@ def test_seq_id_masking_FA_v2(): [3, 2, 1, 0, 0, 0]]).to(torch.int64).cuda() + flash_attn_padding_info_1 = gen_flash_attn_padding_info( + bsz, seqlen_1, 0, query_1.device, attention_mask_in_length_1, None) + output_1, _, _ = flash_attn_fn( query=query_1, key=key_1, @@ -126,7 +137,7 @@ def test_seq_id_masking_FA_v2(): training=False, needs_weights=False, multiquery=False, - attention_mask_in_length=attention_mask_in_length_1) + flash_attn_padding_info=flash_attn_padding_info_1) output_1.sum().backward() @@ -138,21 +149,25 @@ def test_seq_id_masking_FA_v2(): value_2 = value_1.detach().clone()[:, seq_range[0]:seq_range[1], :] value_2.requires_grad = True - output_2, _, _ = flash_attn_fn(query=query_2, - key=key_2, - value=value_2, - n_heads=n_heads, - kv_n_heads=kv_n_heads, - past_key_value=None, - softmax_scale=1 / math.sqrt(d), - attn_bias=None, - key_padding_mask=None, - is_causal=True, - dropout_p=0.0, - training=False, - needs_weights=False, - multiquery=False, - attention_mask_in_length=None) + flash_attn_padding_info_2 = gen_flash_attn_padding_info( + bsz, seq_range[1] - seq_range[0], 0, query_2.device, None, None) + + output_2, _, _ = flash_attn_fn( + query=query_2, + key=key_2, + value=value_2, + n_heads=n_heads, + kv_n_heads=kv_n_heads, + past_key_value=None, + softmax_scale=1 / math.sqrt(d), + attn_bias=None, + key_padding_mask=None, + is_causal=True, + dropout_p=0.0, + training=False, + needs_weights=False, + multiquery=False, + flash_attn_padding_info=flash_attn_padding_info_2) output_2.sum().backward() assert torch.allclose(output_1[:, seq_range[0]:seq_range[1], :], @@ -193,23 +208,25 @@ def test_sliding_window(sliding_window_size: int): device=device) value_1.requires_grad = True - output_1, _, _ = flash_attn_fn(query=query_1, - key=key_1, - value=value_1, - n_heads=n_heads, - kv_n_heads=n_heads, - past_key_value=None, - softmax_scale=1 / math.sqrt(d), - attn_bias=None, - key_padding_mask=None, - is_causal=True, - dropout_p=0.0, - training=False, - needs_weights=False, - multiquery=False, - attention_mask_in_length=None, - should_repeat_kv_for_gqa=True, - sliding_window_size=sliding_window_size) + output_1, _, _ = flash_attn_fn( + query=query_1, + key=key_1, + value=value_1, + n_heads=n_heads, + kv_n_heads=n_heads, + past_key_value=None, + softmax_scale=1 / math.sqrt(d), + attn_bias=None, + key_padding_mask=None, + is_causal=True, + dropout_p=0.0, + training=False, + needs_weights=False, + multiquery=False, + flash_attn_padding_info=gen_flash_attn_padding_info( + bsz, seqlen_1, 0, query_1.device, None, None), + should_repeat_kv_for_gqa=True, + sliding_window_size=sliding_window_size) output_1.sum().backward() @@ -253,3 +270,113 @@ def test_sliding_window(sliding_window_size: int): ) <= 1e-2 + 1e-2 * torch.norm(key_2.grad) assert torch.norm(value_2.grad - value_1.grad # type: ignore ) <= 1e-2 + 1e-2 * torch.norm(value_2.grad) + + +@pytest.mark.gpu +@pytest.mark.skipif( + not check_alibi_support('flash'), + reason='ALiBi only supported by Flash Attention after v2.4.2.') +@pytest.mark.parametrize('n_heads', [1, 6, 8]) +def test_alibi_bias(n_heads: int): + # Test that sliding window attention works as expected. + dtype = torch.bfloat16 + device = 'cuda' + d = 128 + seqlen_1 = 8 + bsz = 2 + + query_1 = torch.randn(bsz, seqlen_1, n_heads * d).to(dtype=dtype, + device=device) + query_1.requires_grad = True + key_1 = torch.randn(bsz, seqlen_1, n_heads * d).to(dtype=dtype, + device=device) + key_1.requires_grad = True + value_1 = torch.randn(bsz, seqlen_1, n_heads * d).to(dtype=dtype, + device=device) + value_1.requires_grad = True + alibi_slopes_1 = gen_slopes(n_heads=n_heads, + alibi_bias_max=8, + device=torch.device(device), + return_1d=True) + output_1, _, _ = flash_attn_fn( + query=query_1, + key=key_1, + value=value_1, + n_heads=n_heads, + kv_n_heads=n_heads, + past_key_value=None, + softmax_scale=1 / math.sqrt(d), + attn_bias=None, + key_padding_mask=None, + is_causal=True, + dropout_p=0.0, + training=False, + needs_weights=False, + multiquery=False, + flash_attn_padding_info=gen_flash_attn_padding_info( + bsz, seqlen_1, 0, query_1.device, None, None), + should_repeat_kv_for_gqa=True, + alibi_slopes=alibi_slopes_1) + + output_1.sum().backward() + + query_2 = query_1.detach().clone() + query_2.requires_grad = True + key_2 = key_1.detach().clone() + key_2.requires_grad = True + value_2 = value_1.detach().clone() + value_2.requires_grad = True + + def gen_bias(): + causal = True + bs = attn_bias_shape('triton', + n_heads, + seqlen_1, + True, + prefix_lm=False, + use_sequence_id=False, + causal=causal) + + attn_bias = torch.zeros(*bs, device=device) + attn_bias = build_attn_bias( + 'triton', + attn_bias, + n_heads, + seqlen_1, + causal=causal, + alibi=True, + alibi_bias_max=8, + ) + return attn_bias + + attn_bias_2 = gen_bias() + + output_2, _, _ = triton_flash_attn_fn( + query=query_2, + key=key_2, + value=value_2, + n_heads=n_heads, + kv_n_heads=n_heads, + past_key_value=None, + softmax_scale=1 / math.sqrt(d), + attn_bias=attn_bias_2, + key_padding_mask=None, + is_causal=True, + dropout_p=0.0, + training=False, + needs_weights=False, + multiquery=False, + ) + + output_2.sum().backward() + + assert torch.allclose(output_1, output_2) + assert (query_2.grad is not None) and (query_1.grad is not None) + assert torch.norm(query_2.grad - + query_1.grad) <= 1e-2 + 1e-2 * torch.norm(query_2.grad) + assert (key_2.grad is not None) and (key_1.grad is not None) + assert torch.norm(key_2.grad - + key_1.grad) <= 1e-2 + 1e-2 * torch.norm(key_2.grad) + assert (value_2.grad is not None) and (value_1.grad is not None) + assert torch.norm(value_2.grad - + value_1.grad) <= 1e-2 + 1e-2 * torch.norm(value_2.grad) diff --git a/tests/models/layers/test_flash_triton_torch.py b/tests/models/layers/test_flash_triton_torch.py index 454fda311d..2f992cd92f 100644 --- a/tests/models/layers/test_flash_triton_torch.py +++ b/tests/models/layers/test_flash_triton_torch.py @@ -6,9 +6,11 @@ from omegaconf import OmegaConf as om from llmfoundry.models.layers import attention -from llmfoundry.models.layers.attention import is_flash_v2_installed +from llmfoundry.models.layers.attention import (check_alibi_support, gen_slopes, + is_flash_v2_installed) from llmfoundry.models.mpt.modeling_mpt import (apply_sequence_id, gen_attention_mask_in_length, + gen_flash_attn_padding_info, gen_rotary_embedding) @@ -20,7 +22,7 @@ def allclose_helper(t0: torch.Tensor, @pytest.mark.gpu -@pytest.mark.parametrize('attn_impl_0,attn_impl_1', [ +@pytest.mark.parametrize('attn_impl_0, attn_impl_1', [ ('flash', 'triton'), ('flash', 'torch'), ('triton', 'torch'), @@ -74,9 +76,9 @@ def test_attn_impl(attn_impl_0: str, """ alibi = pos_emb_config['alibi'] rope = pos_emb_config['rope'] - if alibi and (attn_impl_0 == 'flash' or attn_impl_1 == 'flash'): - pytest.skip('flash attn does not support alibi') - + if alibi and not (check_alibi_support(attn_impl_0) and + check_alibi_support(attn_impl_1)): + pytest.skip('flash attention below v2.4.2 does not support alibi.') if rope and (pos_emb_config['rope_impl'] == 'dail') and (not is_flash_v2_installed()): pytest.skip('dail implementation of rope requires flash attention 2.') @@ -163,6 +165,13 @@ def gen_bias(attn_impl: str): attn_uses_sequence_id=attn_uses_sequence_id, attn_impl=attn_impl_0, attention_mask=attention_mask) + + flash_attn_padding_info_0 = {} + if attn_impl_0 == 'flash': + flash_attn_padding_info_0 = gen_flash_attn_padding_info( + n, s, 0, torch.device(device), attention_mask_in_length_0, + attention_mask) + attention_mask_in_length_1 = gen_attention_mask_in_length( sequence_id=sequence_id, S=s, @@ -170,6 +179,12 @@ def gen_bias(attn_impl: str): attn_impl=attn_impl_1, attention_mask=attention_mask) + flash_attn_padding_info_1 = {} + if attn_impl_1 == 'flash': + flash_attn_padding_info_1 = gen_flash_attn_padding_info( + n, s, 0, torch.device(device), attention_mask_in_length_1, + attention_mask) + x0 = torch.randn(n, s, f).to(device) x1 = x0.clone().detach() x0.requires_grad = True @@ -177,6 +192,12 @@ def gen_bias(attn_impl: str): with torch.autocast(x0.device.type): attn_bias_0 = gen_bias(attn_impl_0) + alibi_slopes_0 = None + if alibi and attn_impl_0 == 'flash': + alibi_slopes_0 = gen_slopes(n_heads=cfg.n_heads, + alibi_bias_max=8, + device=torch.device(device), + return_1d=True) rotary_emb_w_meta_info = None if rope: rotary_embedding = gen_rotary_embedding( @@ -209,15 +230,23 @@ def gen_bias(attn_impl: str): attention_mask=attention_mask, rotary_emb_w_meta_info=rotary_emb_w_meta_info, is_causal=True, - attention_mask_in_length=attention_mask_in_length_0) + flash_attn_padding_info=flash_attn_padding_info_0, + alibi_slopes=alibi_slopes_0) attn_bias_1 = gen_bias(attn_impl_1) + alibi_slopes_1 = None + if alibi and attn_impl_1 == 'flash': + alibi_slopes_1 = gen_slopes(n_heads=cfg.n_heads, + alibi_bias_max=8, + device=torch.device(device), + return_1d=True) y1, _, _ = attn1(x1, past_key_value=None, attn_bias=attn_bias_1, attention_mask=attention_mask, rotary_emb_w_meta_info=rotary_emb_w_meta_info, is_causal=True, - attention_mask_in_length=attention_mask_in_length_1) + flash_attn_padding_info=flash_attn_padding_info_1, + alibi_slopes=alibi_slopes_1) y0 *= attention_mask.unsqueeze(-1) y1 *= attention_mask.unsqueeze(-1) @@ -298,11 +327,16 @@ def gen_tca_mask(): x1.requires_grad = True with torch.autocast(x0.device.type): + flash_attn_padding_info = None + if attn_impl == 'flash': + flash_attn_padding_info = gen_flash_attn_padding_info( + n, s, 0, torch.device(device), None, attention_mask) y0, _, _ = mmhsa(x0, past_key_value=None, attn_bias=None, attention_mask=attention_mask, - is_causal=True) + is_causal=True, + flash_attn_padding_info=flash_attn_padding_info) y1, _ = tmhsa(x1, x1, x1, @@ -372,11 +406,16 @@ def test_grouped_attention_heads(attn_impl: str, x0.requires_grad = True with torch.autocast(x0.device.type): + flash_attn_padding_info = None + if attn_impl == 'flash': + flash_attn_padding_info = gen_flash_attn_padding_info( + n, s, 0, torch.device(device), None, attention_mask) y0, _, _ = mmhsa(x0, past_key_value=None, attn_bias=None, attention_mask=attention_mask, - is_causal=True) + is_causal=True, + flash_attn_padding_info=flash_attn_padding_info) y0 *= attention_mask.unsqueeze(-1) loss0 = y0.sum() diff --git a/tests/models/test_model.py b/tests/models/test_model.py index 7bccad089d..64a92f6cc6 100644 --- a/tests/models/test_model.py +++ b/tests/models/test_model.py @@ -27,7 +27,8 @@ from llmfoundry import COMPOSER_MODEL_REGISTRY, ComposerHFCausalLM from llmfoundry.models.hf.model_wrapper import HuggingFaceModelWithZLoss from llmfoundry.models.layers import NORM_CLASS_REGISTRY, build_alibi_bias -from llmfoundry.models.layers.attention import is_flash_v2_installed +from llmfoundry.models.layers.attention import (check_alibi_support, + is_flash_v2_installed) from llmfoundry.models.layers.blocks import MPTBlock from llmfoundry.models.mpt import MPTConfig, MPTForCausalLM from llmfoundry.utils import build_tokenizer @@ -647,8 +648,8 @@ def test_mpt_creation(norm_type: str, no_bias: bool, tie_word_embeddings: bool, def test_sequence_id_based_masking(attention_impl: str, pos_emb_config: dict): # Testing the output of concatenated sequence with sequence id masking vs individual sequences. alibi = pos_emb_config['alibi'] - if alibi and attention_impl == 'flash': - pytest.skip(f'alibi only implemented with torch and triton attention.') + if alibi and not check_alibi_support(attention_impl): + pytest.skip(f'flash attention below v2.4.2 does not support alibi.') rope = pos_emb_config['rope'] if rope and pos_emb_config[ @@ -766,8 +767,8 @@ def test_forward_with_padding(attention_impl: str, pos_emb_config: dict, tie_word_embeddings: bool): # Test that different placement of padding does not affect the output. alibi = pos_emb_config['alibi'] - if alibi and attention_impl == 'flash': - pytest.skip(f'alibi only implemented with torch and triton attention.') + if alibi and not check_alibi_support(attention_impl): + pytest.skip(f'flash attention below v2.4.2 does not support alibi.') rope = pos_emb_config['rope'] if rope and pos_emb_config[ @@ -1028,8 +1029,8 @@ def test_generate(attention_impl: str, precision: str, pos_emb_config: dict, tie_word_embeddings: bool): # Test that generate works, and produces the same output with or without # padding in the input. - if pos_emb_config['alibi'] and attention_impl == 'flash': - pytest.skip(f'alibi only implemented with torch and triton attention.') + if pos_emb_config['alibi'] and not check_alibi_support(attention_impl): + pytest.skip(f'flash attention below v2.4.2 does not support alibi.') if pos_emb_config['rope'] and pos_emb_config[ 'rope_impl'] == 'dail' and not is_flash_v2_installed(): @@ -1277,8 +1278,8 @@ def test_save_from_pretrained(tmp_path: pathlib.Path): }]) def test_forward_with_cache_and_padding(attn_impl: str, pos_emb_config: dict): # Tests that the result is the same with or without padding when using kv caching - if pos_emb_config['alibi'] and attn_impl == 'flash': - pytest.skip(f'alibi only implemented with torch and triton attention.') + if pos_emb_config['alibi'] and not check_alibi_support(attn_impl): + pytest.skip(f'flash attention below v2.4.2 does not support alibi.') if pos_emb_config['rope'] and pos_emb_config[ 'rope_impl'] == 'dail' and not is_flash_v2_installed(): pytest.skip( @@ -1414,8 +1415,8 @@ def test_forward_with_cache(attn_impl: str, pos_emb_config: dict, tie_word_embeddings: bool): # Test that model forward with and without the key-value cache produces the # same output. - if pos_emb_config['alibi'] and attn_impl == 'flash': - pytest.skip(f'alibi only implemented with torch and triton attention.') + if pos_emb_config['alibi'] and not check_alibi_support(attn_impl): + pytest.skip(f'flash attention below v2.4.2 does not support alibi.') if pos_emb_config['rope'] and pos_emb_config[ 'rope_impl'] == 'dail' and not is_flash_v2_installed(): @@ -1551,8 +1552,8 @@ def test_forward_with_cache(attn_impl: str, pos_emb_config: dict, @pytest.mark.parametrize('tie_word_embeddings', [True, False]) def test_generate_with_past_kv(attn_impl: str, pos_emb_config: dict, tie_word_embeddings: bool): - if pos_emb_config['alibi'] and attn_impl == 'flash': - pytest.skip(f'alibi only implemented with torch and triton attention.') + if pos_emb_config['alibi'] and not check_alibi_support(attn_impl): + pytest.skip(f'flash attention below v2.4.2 does not support alibi.') if pos_emb_config['rope'] and pos_emb_config[ 'rope_impl'] == 'dail' and not is_flash_v2_installed(): pytest.skip( @@ -1658,8 +1659,8 @@ def test_generation_kwargs_dont_crash(attn_impl: str, generation_kwargs: Dict[str, Any], pos_emb_config: dict, tie_word_embeddings: bool): - if pos_emb_config['alibi'] and attn_impl == 'flash': - pytest.skip(f'alibi only implemented with torch and triton attention.') + if pos_emb_config['alibi'] and not check_alibi_support(attn_impl): + pytest.skip(f'flash attention below v2.4.2 does not support alibi.') if pos_emb_config['rope'] and pos_emb_config[ 'rope_impl'] == 'dail' and not is_flash_v2_installed(): @@ -1847,8 +1848,8 @@ def test_alibi_vs_hf(): }]) def test_forward_with_output_attentions_and_output_hidden_states( attn_impl: str, pos_emb_config: dict): - if pos_emb_config['alibi'] and attn_impl == 'flash': - pytest.skip(f'alibi only implemented with torch and triton attention.') + if pos_emb_config['alibi'] and not check_alibi_support(attn_impl): + pytest.skip(f'flash attention below v2.4.2 does not support alibi.') if attn_impl in ['flash', 'triton']: pytest.skip(f'output_attentions only implemented with torch attention.') if pos_emb_config['rope'] and pos_emb_config[ diff --git a/tests/models/test_rope_dail_vs_hf.py b/tests/models/test_rope_dail_vs_hf.py index 70a00470f9..33c3d3c052 100644 --- a/tests/models/test_rope_dail_vs_hf.py +++ b/tests/models/test_rope_dail_vs_hf.py @@ -7,7 +7,8 @@ from omegaconf import OmegaConf as om from llmfoundry.models.layers.attention import is_flash_v2_installed -from llmfoundry.models.mpt.modeling_mpt import gen_rotary_embedding +from llmfoundry.models.mpt.modeling_mpt import (gen_flash_attn_padding_info, + gen_rotary_embedding) @pytest.mark.gpu @@ -104,14 +105,20 @@ def test_rope_dail_vs_hf(attn_type: str, seq_len: int, device: str = 'cuda'): attn_bias=None, attention_mask=attention_mask, rotary_emb_w_meta_info=dail_rope_w_meta_info, - is_causal=True) + is_causal=True, + flash_attn_padding_info=gen_flash_attn_padding_info( + batch_size, seq_len, 0, torch.device(device), None, + attention_mask)) y1, _, _ = attn1(x1, past_key_value=None, attn_bias=None, attention_mask=attention_mask, rotary_emb_w_meta_info=hf_rope_w_meta_info, - is_causal=True) + is_causal=True, + flash_attn_padding_info=gen_flash_attn_padding_info( + batch_size, seq_len, 0, torch.device(device), None, + attention_mask)) y0 *= attention_mask.unsqueeze(-1) y1 *= attention_mask.unsqueeze(-1) diff --git a/tests/optim/test_scheduler.py b/tests/optim/test_scheduler.py index 5b9d45a141..811088bd62 100644 --- a/tests/optim/test_scheduler.py +++ b/tests/optim/test_scheduler.py @@ -88,7 +88,7 @@ def test_scheduler_units_match_error(state_unit: str, warmup_unit: str, t_warmup=f'10{warmup_unit}', t_scale=f'10{scale_unit}', t_cooldown=f'10{cooldown_unit}') - with pytest.raises(ValueError, match='does not match'): + with pytest.raises(ValueError, match='must match'): _ = scheduler(state, 1.0) diff --git a/tests/tokenizers/test_tiktoken.py b/tests/tokenizers/test_tiktoken.py index 1ade2ea156..aca269af82 100644 --- a/tests/tokenizers/test_tiktoken.py +++ b/tests/tokenizers/test_tiktoken.py @@ -108,6 +108,12 @@ 'Please summarize the goals in this text:\n\nGoing outside has benefits include reducing stress and triggering the relaxation response, which can help us not only feel better mentally, but even heal faster from physical ailments.', 'role': 'user' +}, { + 'content': 'You should go outside and touch grass.', + 'role': 'assistant' +}, { + 'content': 'What else can I do?', + 'role': 'user' }]] MULTI_TURN_GENERATE_STRING = [ @@ -118,6 +124,10 @@ Going outside has benefits include reducing stress and triggering the relaxation response, which can help us not only feel better mentally, but even heal faster from physical ailments.<|im_end|> <|im_start|>assistant +You should go outside and touch grass.<|im_end|><|endoftext|> +<|im_start|>user +What else can I do?<|im_end|> +<|im_start|>assistant """ ] @@ -338,6 +348,7 @@ def test_additional_special_tokens(model_name: Optional[str], encoding_name: Optional[str], tmp_path: pathlib.Path): special_token_to_add = '<|im_start|>' + input_string = special_token_to_add + ' hello' wrapped_tokenizer, _, _ = get_tokenizers_for_testing( model_name, encoding_name, @@ -345,12 +356,15 @@ def test_additional_special_tokens(model_name: Optional[str], add_bos_token=False, add_eos_token=False, additional_special_tokens=[special_token_to_add]) - encoded_outputs = wrapped_tokenizer(special_token_to_add + - ' hello')['input_ids'] + encoded_outputs = wrapped_tokenizer(input_string)['input_ids'] assert encoded_outputs[0] == wrapped_tokenizer.vocab_size assert len(encoded_outputs) == 2 + decoded_outputs = wrapped_tokenizer.decode( + encoded_outputs, spaces_between_special_tokens=False) + assert decoded_outputs == input_string + @pytest.mark.parametrize('model_name,encoding_name', MODEL_ENCODING_NAME_PARAMETRIZATION) @@ -386,3 +400,15 @@ def test_chat_formatting(model_name: Optional[str], chat_str = wrapped_tokenizer.apply_chat_template( dict_chats, tokenize=False, add_generation_prompt=True) assert chat_str == MULTI_TURN_GENERATE_STRING[i] + + +def test_tiktoken_out_of_range(): + wrapped_tokenizer = TiktokenTokenizerWrapper(model_name='gpt-4',) + + # For gpt-4, 100256 is less than the vocab size, but is not a valid token + assert wrapped_tokenizer.decode([100256]) == '' + assert wrapped_tokenizer.decode(100256) == '' + + # For gpt-4, 1000000 is greater than the vocab size + assert wrapped_tokenizer.decode([1000000]) == '' + assert wrapped_tokenizer.decode(1000000) == '' diff --git a/tests/utils/test_builders.py b/tests/utils/test_builders.py index 9be6630075..303afc9b7d 100644 --- a/tests/utils/test_builders.py +++ b/tests/utils/test_builders.py @@ -135,14 +135,14 @@ def test_build_logger(): with pytest.raises(ValueError): _ = build_logger('unknown', {}) - logger_cfg = DictConfig({ + logger_cfg = { 'project': 'foobar', 'init_kwargs': { 'config': { 'foo': 'bar', } } - }) + } wandb_logger = build_logger('wandb', logger_cfg) # type: ignore assert isinstance(wandb_logger, WandBLogger) assert wandb_logger.project == 'foobar' diff --git a/tests/utils/test_model_download_utils.py b/tests/utils/test_model_download_utils.py index 27b9805cda..14749bdcd9 100644 --- a/tests/utils/test_model_download_utils.py +++ b/tests/utils/test_model_download_utils.py @@ -16,11 +16,9 @@ from transformers.utils import WEIGHTS_INDEX_NAME as PYTORCH_WEIGHTS_INDEX_NAME from transformers.utils import WEIGHTS_NAME as PYTORCH_WEIGHTS_NAME -from llmfoundry.utils.model_download_utils import (DEFAULT_IGNORE_PATTERNS, - PYTORCH_WEIGHTS_PATTERN, - SAFE_WEIGHTS_PATTERN, - download_from_cache_server, - download_from_hf_hub) +from llmfoundry.utils.model_download_utils import ( + DEFAULT_IGNORE_PATTERNS, PYTORCH_WEIGHTS_PATTERN, SAFE_WEIGHTS_PATTERN, + download_from_hf_hub, download_from_http_fileserver) # ======================== download_from_hf_hub tests ======================== @@ -103,15 +101,18 @@ def test_download_from_hf_hub_weights_pref(mock_list_repo_files: MagicMock, repo_files: List[str], expected_ignore_patterns: List[str]): test_repo_id = 'test_repo_id' + save_dir = 'save_dir' mock_list_repo_files.return_value = repo_files - download_from_hf_hub(test_repo_id, prefer_safetensors=prefer_safetensors) + download_from_hf_hub(test_repo_id, + save_dir=save_dir, + prefer_safetensors=prefer_safetensors) mock_snapshot_download.assert_called_once_with( test_repo_id, - cache_dir=None, + local_dir=save_dir, + allow_patterns=None, ignore_patterns=expected_ignore_patterns, - token=None, - ) + token=None) @mock.patch('huggingface_hub.snapshot_download') @@ -121,10 +122,11 @@ def test_download_from_hf_hub_no_weights( mock_snapshot_download: MagicMock, ): test_repo_id = 'test_repo_id' + save_dir = 'save_dir' mock_list_repo_files.return_value = [] with pytest.raises(ValueError): - download_from_hf_hub(test_repo_id) + download_from_hf_hub(test_repo_id, save_dir) mock_snapshot_download.assert_not_called() @@ -148,12 +150,12 @@ def test_download_from_hf_hub_retry( mock_snapshot_download.side_effect = exception with pytest.raises((tenacity.RetryError, exception.__class__)): - download_from_hf_hub('test_repo_id') + download_from_hf_hub('test_repo_id', 'save_dir') assert mock_snapshot_download.call_count == expected_attempts -# ======================== download_from_cache_server tests ======================== +# ======================== download_from_http_fileserver tests ======================== ROOT_HTML = b""" @@ -182,51 +184,47 @@ def test_download_from_hf_hub_retry( @mock.patch.object(requests.Session, 'get') @mock.patch('os.makedirs') @mock.patch('builtins.open') -def test_download_from_cache_server(mock_open: MagicMock, - mock_makedirs: MagicMock, - mock_get: MagicMock): - cache_url = 'https://cache.com/' - model_name = 'model' - formatted_model_name = 'models--model' +def test_download_from_http_fileserver(mock_open: MagicMock, + mock_makedirs: MagicMock, + mock_get: MagicMock): + model_url = f'https://cache.com/models/model/' save_dir = 'save_dir/' mock_open.return_value = MagicMock() def _server_response(url: str, **kwargs: Dict[str, Any]): - if url == urljoin(cache_url, f'{formatted_model_name}/blobs/'): + if url == model_url: return MagicMock(status_code=HTTPStatus.OK, content=ROOT_HTML) - if url == urljoin(cache_url, f'{formatted_model_name}/blobs/file1'): + if url == urljoin(model_url, 'file1'): return MagicMock(status_code=HTTPStatus.OK) - elif url == urljoin(cache_url, f'{formatted_model_name}/blobs/folder/'): + elif url == urljoin(model_url, 'folder/'): return MagicMock(status_code=HTTPStatus.OK, content=SUBFOLDER_HTML) - elif url == urljoin(cache_url, - f'{formatted_model_name}/blobs/folder/file2'): + elif url == urljoin(model_url, 'folder/file2'): return MagicMock(status_code=HTTPStatus.OK) else: return MagicMock(status_code=HTTPStatus.NOT_FOUND) mock_get.side_effect = _server_response - download_from_cache_server(model_name, cache_url, 'save_dir/') + download_from_http_fileserver(model_url, save_dir) - mock_open.assert_has_calls([ - mock.call(os.path.join(save_dir, formatted_model_name, 'blobs/file1'), - 'wb'), - mock.call( - os.path.join(save_dir, formatted_model_name, 'blobs/folder/file2'), - 'wb'), - ], - any_order=True) + mock_open.assert_has_calls( + [ + mock.call(os.path.join(save_dir, 'file1'), 'wb'), + mock.call(os.path.join(save_dir, 'folder/file2'), 'wb'), + ], + any_order=True, + ) @mock.patch.object(requests.Session, 'get') -def test_download_from_cache_server_unauthorized(mock_get: MagicMock): - cache_url = 'https://cache.com/' +def test_download_from_http_fileserver_unauthorized(mock_get: MagicMock): model_name = 'model' + cache_url = f'https://cache.com/models--{model_name}/blobs/' save_dir = 'save_dir/' mock_get.return_value = MagicMock(status_code=HTTPStatus.UNAUTHORIZED) with pytest.raises(PermissionError): - download_from_cache_server(model_name, cache_url, save_dir) + download_from_http_fileserver(cache_url, save_dir) @pytest.mark.parametrize(['exception', 'expected_attempts'], [ @@ -236,7 +234,7 @@ def test_download_from_cache_server_unauthorized(mock_get: MagicMock): ]) @mock.patch('tenacity.nap.time.sleep') @mock.patch('llmfoundry.utils.model_download_utils._recursive_download') -def test_download_from_cache_server_retry( +def test_download_from_http_fileserver_retry( mock_recursive_download: MagicMock, mock_sleep: MagicMock, # so the retry wait doesn't actually wait exception: BaseException, @@ -245,4 +243,4 @@ def test_download_from_cache_server_retry( mock_recursive_download.side_effect = exception with pytest.raises((tenacity.RetryError, exception.__class__)): - download_from_cache_server('model', 'cache_url', 'save_dir') + download_from_http_fileserver('cache_url', 'save_dir')