From 3a96b69965189876ff3bccceebb26d991e9bea72 Mon Sep 17 00:00:00 2001
From: Anna
Date: Wed, 29 Nov 2023 10:29:07 -0800
Subject: [PATCH 1/6] Add script for doing bulk generation against an endpoint (#765)

* Add script for doing bulk generation against an endpoint
* more logging
* warn
* fix
* format
* asdfads
* Add warning
* updates
* folder -> file
* remove blank line
* Support remote input
* prompts -> inputs
---
 llmfoundry/utils/prompt_files.py       |  58 +++++++
 scripts/inference/endpoint_generate.py | 223 +++++++++++++++++++++++++
 scripts/inference/hf_generate.py       |  31 ++--
 tests/test_prompt_files.py             |  18 ++
 4 files changed, 309 insertions(+), 21 deletions(-)
 create mode 100644 llmfoundry/utils/prompt_files.py
 create mode 100644 scripts/inference/endpoint_generate.py
 create mode 100644 tests/test_prompt_files.py

diff --git a/llmfoundry/utils/prompt_files.py b/llmfoundry/utils/prompt_files.py
new file mode 100644
index 0000000000..40de19907a
--- /dev/null
+++ b/llmfoundry/utils/prompt_files.py
@@ -0,0 +1,58 @@
+# Copyright 2022 MosaicML LLM Foundry authors
+# SPDX-License-Identifier: Apache-2.0
+
+import os
+from typing import List, Optional
+
+PROMPTFILE_PREFIX = 'file::'
+
+
+def load_prompts(prompts: List[str],
+                 prompt_delimiter: Optional[str] = None) -> List[str]:
+    """Loads a set of prompts, both free text and from file.
+
+    Args:
+        prompts (List[str]): List of free text prompts and prompt files
+        prompt_delimiter (Optional str): Delimiter for text file
+            If not provided, assumes the prompt file is a single prompt (non-delimited)
+
+    Returns:
+        List of prompt string(s)
+    """
+    prompt_strings = []
+    for prompt in prompts:
+        if prompt.startswith(PROMPTFILE_PREFIX):
+            prompts = load_prompts_from_file(prompt, prompt_delimiter)
+            prompt_strings.extend(prompts)
+        else:
+            prompt_strings.append(prompt)
+    return prompt_strings
+
+
+def load_prompts_from_file(prompt_path: str,
+                           prompt_delimiter: Optional[str] = None) -> List[str]:
+    """Load a set of prompts from a text file.
+
+    Args:
+        prompt_path (str): Path for text file
+        prompt_delimiter (Optional str): Delimiter for text file
+            If not provided, assumes the prompt file is a single prompt (non-delimited)
+
+    Returns:
+        List of prompt string(s)
+    """
+    if not prompt_path.startswith(PROMPTFILE_PREFIX):
+        raise ValueError(f'prompt_path_str must start with {PROMPTFILE_PREFIX}')
+
+    _, prompt_file_path = prompt_path.split(PROMPTFILE_PREFIX, maxsplit=1)
+    prompt_file_path = os.path.expanduser(prompt_file_path)
+    if not os.path.isfile(prompt_file_path):
+        raise FileNotFoundError(
+            f'{prompt_file_path=} does not match any existing files.')
+
+    with open(prompt_file_path, 'r') as f:
+        prompt_string = f.read()
+
+    if prompt_delimiter is None:
+        return [prompt_string]
+    return [i for i in prompt_string.split(prompt_delimiter) if i]
diff --git a/scripts/inference/endpoint_generate.py b/scripts/inference/endpoint_generate.py
new file mode 100644
index 0000000000..e78fecf59b
--- /dev/null
+++ b/scripts/inference/endpoint_generate.py
@@ -0,0 +1,223 @@
+# Copyright 2022 MosaicML LLM Foundry authors
+# SPDX-License-Identifier: Apache-2.0
+
+"""Batch generate text completion results from an endpoint.
+ +Warning: This script is experimental and could change or be removed at any time +""" + +import asyncio +import copy +import logging +import math +import os +import tempfile +import time +from argparse import ArgumentParser, Namespace + +import pandas as pd +import requests +from composer.utils import (get_file, maybe_create_object_store_from_uri, + parse_uri) + +from llmfoundry.utils import prompt_files as utils + +logging.basicConfig(level=logging.INFO, format='%(asctime)s %(message)s') +log = logging.getLogger(__name__) + +ENDPOINT_API_KEY_ENV: str = 'ENDPOINT_API_KEY' +ENDPOINT_URL_ENV: str = 'ENDPOINT_URL' + +PROMPT_DELIMITER = '\n' + + +def parse_args() -> Namespace: + """Parse commandline arguments.""" + parser = ArgumentParser( + description='Call prompts against a text completions endpoint') + + ##### + # Path Parameters + parser.add_argument( + '-i', + '--inputs', + nargs='+', + help=f'List of strings, local datafiles (starting with {utils.PROMPTFILE_PREFIX}),' +\ + ' and/or remote object stores' + ) + parser.add_argument( + '--prompt-delimiter', + default='\n', + help= + 'Prompt delimiter for txt files. By default, a file is a single prompt') + + parser.add_argument('-o', + '--output-folder', + required=True, + help='Remote folder to save the output') + + ##### + # Generation Parameters + parser.add_argument( + '--rate-limit', + type=int, + default=75, + help='Max number of calls to make to the endpoint in a second') + parser.add_argument( + '--batch-size', + type=int, + default=10, + help='Max number of calls to make to the endpoint in a single request') + + ##### + # Endpoint Parameters + parser.add_argument( + '-e', + '--endpoint', + type=str, + help= + f'OpenAI-compatible text completions endpoint to query on. If not set, will read from {ENDPOINT_URL_ENV}' + ) + + parser.add_argument('--max-tokens', type=int, default=100) + parser.add_argument('--temperature', type=float, default=1.0) + parser.add_argument('--top-k', type=int, default=50) + parser.add_argument('--top-p', type=float, default=1.0) + return parser.parse_args() + + +async def main(args: Namespace) -> None: + # This is mildly experimental, so for now imports are not added as part of llm-foundry + try: + import aiohttp + except ImportError as e: + raise ImportError('Please install aiohttp') from e + + try: + from ratelimit import limits, sleep_and_retry + except ImportError as e: + raise ImportError('Please install ratelimit') from e + + if args.batch_size > args.rate_limit: + raise ValueError( + f'Batch size is {args.batch_size} but rate limit is set to {args.rate_limit} / s' + ) + + url = args.endpoint if args.endpoint else os.environ.get(ENDPOINT_URL_ENV) + if not url: + raise ValueError( + f'URL must be provided via --endpoint or {ENDPOINT_URL_ENV}') + + log.info(f'Using endpoint {url}') + + api_key = os.environ.get(ENDPOINT_API_KEY_ENV, '') + if not api_key: + log.warning(f'API key not set in {ENDPOINT_API_KEY_ENV}') + + new_inputs = [] + for prompt in args.inputs: + if prompt.startswith(utils.PROMPTFILE_PREFIX): + new_inputs.append(prompt) + continue + + input_object_store = maybe_create_object_store_from_uri(prompt) + if input_object_store is not None: + local_output_path = tempfile.TemporaryDirectory().name + get_file(prompt, str(local_output_path)) + log.info(f'Downloaded {prompt} to {local_output_path}') + prompt = f'{utils.PROMPTFILE_PREFIX}{local_output_path}' + + new_inputs.append(prompt) + + prompt_strings = utils.load_prompts(new_inputs, args.prompt_delimiter) + + cols = ['batch', 'prompt', 
'output'] + param_data = { + 'max_tokens': args.max_tokens, + 'temperature': args.temperature, + 'top_k': args.top_k, + 'top_p': args.top_p, + } + + total_batches = math.ceil(len(prompt_strings) / args.batch_size) + log.info( + f'Generating {len(prompt_strings)} prompts in {total_batches} batches') + + @sleep_and_retry + @limits(calls=total_batches, period=1) # type: ignore + async def generate(session: aiohttp.ClientSession, batch: int, + prompts: list): + data = copy.copy(param_data) + data['prompt'] = prompts + headers = {'Authorization': api_key, 'Content-Type': 'application/json'} + req_start = time.time() + async with session.post(url, headers=headers, json=data) as resp: + if resp.ok: + try: + response = await resp.json() + except requests.JSONDecodeError: + raise Exception( + f'Bad response: {resp.status} {resp.reason}') + else: + raise Exception(f'Bad response: {resp.status} {resp.reason}') + + req_end = time.time() + n_compl = response['usage']['completion_tokens'] + n_prompt = response['usage']['prompt_tokens'] + req_latency = (req_end - req_start) + log.info(f'Completed batch {batch}: {n_compl:,} completion' + + f' tokens using {n_prompt:,} prompt tokens in {req_latency}s') + + res = pd.DataFrame(columns=cols) + + for r in response['choices']: + index = r['index'] + res.loc[len(res)] = [batch, prompts[index], r['text']] + return res + + res = pd.DataFrame(columns=cols) + batch = 0 + + gen_start = time.time() + async with aiohttp.ClientSession() as session: + tasks = [] + + for i in range(total_batches): + prompts = prompt_strings[i * args.batch_size:min( + (i + 1) * args.batch_size, len(prompt_strings))] + + tasks.append(generate(session, batch, prompts)) + batch += 1 + + results = await asyncio.gather(*tasks) + res = pd.concat(results) + + res.reset_index(drop=True, inplace=True) + + gen_end = time.time() + gen_latency = (gen_end - gen_start) + log.info(f'Generated {len(res)} prompts in {gen_latency}s, example data:') + log.info(res.head()) + + with tempfile.TemporaryDirectory() as tmp_dir: + file = 'output.csv' + local_path = os.path.join(tmp_dir, file) + res.to_csv(local_path, index=False) + + output_object_store = maybe_create_object_store_from_uri( + args.output_folder) + if output_object_store is not None: + _, _, output_folder_prefix = parse_uri(args.output_folder) + remote_path = os.path.join(output_folder_prefix, file) + output_object_store.upload_object(remote_path, local_path) + output_object_store.download_object + log.info(f'Uploaded results to {args.output_folder}/{file}') + else: + output_dir, _ = os.path.split(args.output_folder) + os.makedirs(output_dir, exist_ok=True) + os.rename(local_path, args.output_folder) + log.info(f'Saved results to {args.output_folder}') + + +if __name__ == '__main__': + asyncio.run(main(parse_args())) diff --git a/scripts/inference/hf_generate.py b/scripts/inference/hf_generate.py index 45ddc6b63e..6ac645e5b7 100644 --- a/scripts/inference/hf_generate.py +++ b/scripts/inference/hf_generate.py @@ -1,7 +1,6 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 import itertools -import os import random import time import warnings @@ -13,6 +12,8 @@ import torch from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer +from llmfoundry.utils import prompt_files as utils + def get_dtype(dtype: str): if dtype == 'fp32': @@ -62,9 +63,14 @@ def parse_args() -> Namespace: 'My name is', 'This is an explanation of deep learning to a five year old. Deep learning is', ], - help='Generation prompts. 
Use syntax "file::/path/to/prompt.txt" to load a ' +\ - 'prompt contained in a txt file.' + help='List of generation prompts or list of delimited files. Use syntax ' +\ + '"file::/path/to/prompt.txt" to load a prompt(s) contained in a txt file.' ) + parser.add_argument( + '--prompt-delimiter', + default=None, + help= + 'Prompt delimiter for txt files. By default, a file is a single prompt') parser.add_argument('--max_seq_len', type=int, default=None) parser.add_argument('--max_new_tokens', type=int, default=100) parser.add_argument('--max_batch_size', type=int, default=None) @@ -125,19 +131,6 @@ def parse_args() -> Namespace: return parser.parse_args() -def load_prompt_string_from_file(prompt_path_str: str): - if not prompt_path_str.startswith('file::'): - raise ValueError('prompt_path_str must start with "file::".') - _, prompt_file_path = prompt_path_str.split('file::', maxsplit=1) - prompt_file_path = os.path.expanduser(prompt_file_path) - if not os.path.isfile(prompt_file_path): - raise FileNotFoundError( - f'{prompt_file_path=} does not match any existing files.') - with open(prompt_file_path, 'r') as f: - prompt_string = ''.join(f.readlines()) - return prompt_string - - def maybe_synchronize(): if torch.cuda.is_available(): torch.cuda.synchronize() @@ -163,11 +156,7 @@ def main(args: Namespace) -> None: print(f'Using {model_dtype=}') # Load prompts - prompt_strings = [] - for prompt in args.prompts: - if prompt.startswith('file::'): - prompt = load_prompt_string_from_file(prompt) - prompt_strings.append(prompt) + prompt_strings = utils.load_prompts(args.prompts, args.prompt_delimiter) # Grab config first print(f'Loading HF Config...') diff --git a/tests/test_prompt_files.py b/tests/test_prompt_files.py new file mode 100644 index 0000000000..12a5d02999 --- /dev/null +++ b/tests/test_prompt_files.py @@ -0,0 +1,18 @@ +# Copyright 2022 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +from pathlib import Path + +from llmfoundry.utils import prompt_files as utils + + +def test_load_prompt_strings(tmp_path: Path): + assert utils.load_prompts(['hello', 'goodbye']) == ['hello', 'goodbye'] + + with open(tmp_path / 'prompts.txt', 'w') as f: + f.write('hello goodbye') + + temp = utils.PROMPTFILE_PREFIX + str(tmp_path / 'prompts.txt') + assert utils.load_prompts( + [temp, temp, 'why'], + ' ') == ['hello', 'goodbye', 'hello', 'goodbye', 'why'] From 1191267195367b5ec6093ed7854b8f6daf1be2d3 Mon Sep 17 00:00:00 2001 From: Irene Dea Date: Wed, 29 Nov 2023 12:14:02 -0800 Subject: [PATCH 2/6] Only strip object names when creating new output path (#766) --- llmfoundry/utils/data_prep_utils.py | 11 +++++----- tests/test_convert_text_to_mds.py | 31 +++++++++++++++++++++++++++++ 2 files changed, 37 insertions(+), 5 deletions(-) diff --git a/llmfoundry/utils/data_prep_utils.py b/llmfoundry/utils/data_prep_utils.py index 75e27b504f..a88e65ee94 100644 --- a/llmfoundry/utils/data_prep_utils.py +++ b/llmfoundry/utils/data_prep_utils.py @@ -96,15 +96,16 @@ def __init__( def __iter__(self): for object_name in self.object_names: - object_name = object_name.strip('/') - output_filename = os.path.join(self.output_folder, object_name) + # Default output_filename, used for local paths. + output_filename = object_name + + # Download objects if remote path. 
if self.object_store is not None: + output_filename = os.path.join(self.output_folder, + object_name.strip('/')) self.object_store.download_object(object_name=object_name, filename=output_filename, overwrite=True) - else: - # Inputs are local so we do not need to download them. - output_filename = object_name with open(output_filename) as _txt_file: txt = _txt_file.read() diff --git a/tests/test_convert_text_to_mds.py b/tests/test_convert_text_to_mds.py index 2d4878ebbb..ab8c25bc2d 100644 --- a/tests/test_convert_text_to_mds.py +++ b/tests/test_convert_text_to_mds.py @@ -188,6 +188,37 @@ def test_single_and_multi_process(merge_shard_groups: Mock, assert n_tokens == expected_n_tokens +def test_local_path(tmp_path: pathlib.Path): + # Input/output folders + input_folder = tmp_path / 'input' + output_folder = tmp_path / 'output' + + # Create input text data + os.makedirs(input_folder, exist_ok=True) + with open(input_folder / 'test.txt', 'w') as f: + f.write('test') + + # Convert text data to mds + convert_text_to_mds( + tokenizer_name='mosaicml/mpt-7b', + output_folder=str(output_folder), + input_folder=str(input_folder), + concat_tokens=1, + eos_text='', + bos_text='', + no_wrap=False, + compression='zstd', + processes=1, + args_str='Namespace()', + reprocess=False, + ) + + # Make sure all the files exist as expected. + assert os.path.exists(output_folder / '.text_to_mds_conversion_done') + assert os.path.exists(output_folder / 'index.json') + assert os.path.exists(output_folder / 'shard.00000.mds.zstd') + + def test_is_already_processed(tmp_path: pathlib.Path): tmp_path_str = str(tmp_path) args_str = 'Namespace(x = 5)' From 3100859905c1ed29e049e7c203cf70da8231f2e6 Mon Sep 17 00:00:00 2001 From: Anna Date: Thu, 30 Nov 2023 14:02:13 -0800 Subject: [PATCH 3/6] Add eval loader to eval script (#742) * Add eval loader to eval script * small input tests * updates * fix typing and formatting * fixes, add tests * remove circular dependency * tests pass * nits + small fixes * add metrics at the end, refactor to put icl/gauntlet as helpers * NOT * metrics instead of models, add unit tests --- llmfoundry/data/dataloader.py | 32 ++++----- llmfoundry/utils/builders.py | 81 +++++++++++++++++++++++ scripts/eval/eval.py | 53 +++++++++++---- scripts/train/train.py | 52 +++++---------- tests/data_utils.py | 98 +++++++++++++++++++++++++++- tests/test_builders.py | 118 +++++++++++++++++++++++++++++++++- tests/test_dataloader.py | 11 ++++ tests/test_eval.py | 89 +++++++++++++++++++++++++ tests/test_eval_inputs.py | 1 + tests/test_train_inputs.py | 2 +- tests/test_training.py | 97 ++-------------------------- 11 files changed, 469 insertions(+), 165 deletions(-) diff --git a/llmfoundry/data/dataloader.py b/llmfoundry/data/dataloader.py index 12741717be..63d47a65d5 100644 --- a/llmfoundry/data/dataloader.py +++ b/llmfoundry/data/dataloader.py @@ -11,6 +11,12 @@ from llmfoundry.data.finetuning.dataloader import build_finetuning_dataloader from llmfoundry.data.text_data import build_text_dataloader +LOADER_NAME_TO_FUNCTION = { + 'text': build_text_dataloader, + 'text_denoising': build_text_denoising_dataloader, + 'finetuning': build_finetuning_dataloader, +} + def build_dataloader(cfg: DictConfig, tokenizer: PreTrainedTokenizerBase, device_batch_size: int) -> DataSpec: @@ -22,23 +28,9 @@ def build_dataloader(cfg: DictConfig, tokenizer: PreTrainedTokenizerBase, device_batch_size (int): The size of the batches (number of examples) that the dataloader will produce. 
""" - if cfg.name == 'text': - return build_text_dataloader( - cfg, - tokenizer, - device_batch_size, - ) - elif cfg.name == 'text_denoising': - return build_text_denoising_dataloader( - cfg, - tokenizer, - device_batch_size, - ) - elif cfg.name == 'finetuning': - return build_finetuning_dataloader( - cfg, - tokenizer, - device_batch_size, - ) - else: - raise ValueError(f'Not sure how to build dataloader with config: {cfg}') + if cfg.name not in LOADER_NAME_TO_FUNCTION: + allowed = ', '.join(LOADER_NAME_TO_FUNCTION.keys()) + raise ValueError(f'Expected dataloader name to be one of {allowed}' + + f' but found name "{cfg.name}" in config: {cfg}') + + return LOADER_NAME_TO_FUNCTION[cfg.name](cfg, tokenizer, device_batch_size) diff --git a/llmfoundry/utils/builders.py b/llmfoundry/utils/builders.py index 14196c3ef9..a672fbee55 100644 --- a/llmfoundry/utils/builders.py +++ b/llmfoundry/utils/builders.py @@ -28,12 +28,14 @@ from omegaconf import DictConfig, ListConfig from omegaconf import OmegaConf as om from torch.optim.optimizer import Optimizer +from torchmetrics import Metric from transformers import AutoTokenizer, PreTrainedTokenizerBase from llmfoundry.callbacks import (EvalGauntlet, FDiffMetrics, GlobalLRScaling, HuggingFaceCheckpointer, LayerFreezing, MonolithicCheckpointSaver, ScheduledGarbageCollector) +from llmfoundry.data.dataloader import build_dataloader from llmfoundry.optim import (DecoupledAdaLRLion, DecoupledClipLion, DecoupledLionW, DecoupledLionW_8bit) from llmfoundry.optim.scheduler import InverseSquareRootWithWarmupScheduler @@ -42,6 +44,85 @@ log = logging.getLogger(__name__) +def build_evaluators( + eval_loader_config: Optional[Union[DictConfig, ListConfig]], + icl_tasks_config: Optional[Union[str, ListConfig]], + eval_gauntlet_config: Optional[Union[str, DictConfig]], + *, + tokenizer: PreTrainedTokenizerBase, + device_eval_batch_size: int, + icl_seq_len: int, + icl_subset_num_batches: Optional[int], +) -> Tuple[List[Evaluator], List[str], Optional[EvalGauntlet]]: + + evaluators = [] + if eval_loader_config is not None: + evaluators = build_eval_loaders( + eval_loader_config, + tokenizer, + device_eval_batch_size, + ) + + logger_keys = [] + eval_gauntlet_callback = None + if icl_tasks_config is not None: + icl_evaluators, logger_keys, eval_gauntlet_callback = build_icl_data_and_gauntlet( + icl_tasks_config, + eval_gauntlet_config, + tokenizer, + device_eval_batch_size, + icl_seq_len, + icl_subset_num_batches, + ) + evaluators.extend(icl_evaluators) + + return evaluators, logger_keys, eval_gauntlet_callback + + +def build_eval_loaders( + eval_loader_config: Union[DictConfig, ListConfig], + tokenizer: PreTrainedTokenizerBase, + device_eval_batch_size: int, +) -> List[Evaluator]: + evaluators: List[Evaluator] = [] + if isinstance(eval_loader_config, ListConfig): + eval_configs: ListConfig = eval_loader_config + is_multi_eval = True + else: + eval_configs = ListConfig([eval_loader_config]) + is_multi_eval = False + + for eval_config in eval_configs: + eval_dataloader = build_dataloader(eval_config, tokenizer, + device_eval_batch_size) + eval_loader: Evaluator = Evaluator( + label=f'eval/{eval_config.label}' if is_multi_eval else 'eval', + dataloader=eval_dataloader, + # Load the eval data to fail fast. 
metrics will get added + # later in add_metrics_to_eval_loaders, after the model is loaded + metric_names=[], + ) + evaluators.append(eval_loader) + return evaluators + + +def add_metrics_to_eval_loaders( + evaluators: List[Evaluator], + metrics: Dict[str, Metric], +) -> List[Evaluator]: + metric_names = list(metrics.keys()) + eval_loaders, other_evaluators = [], [] + for evaluator in evaluators: + if evaluator.metric_names == []: + evaluator.metric_names = metric_names + eval_loaders.append(evaluator) + else: + other_evaluators.append(evaluator) + + # Put the base eval_loaders first + return eval_loaders + other_evaluators + + def build_icl_data_and_gauntlet( icl_tasks_config: Union[str, ListConfig], eval_gauntlet_config: Optional[Union[str, DictConfig]], diff --git a/scripts/eval/eval.py b/scripts/eval/eval.py index 02a5d1f862..369a894720 100644 --- a/scripts/eval/eval.py +++ b/scripts/eval/eval.py @@ -6,7 +6,7 @@ import sys import time import warnings -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Tuple, Union import pandas as pd import torch @@ -21,13 +21,14 @@ from llmfoundry.models import MPTForCausalLM from llmfoundry.models.model_registry import COMPOSER_MODEL_REGISTRY -from llmfoundry.utils.builders import (build_icl_data_and_gauntlet, - build_logger, build_tokenizer) +from llmfoundry.utils.builders import (add_metrics_to_eval_loaders, + build_evaluators, build_logger, + build_tokenizer) from llmfoundry.utils.config_utils import pop_config, process_init_device def load_peft_model(model_cfg: DictConfig, tokenizer: PreTrainedTokenizerBase, - num_retries: int) -> Optional[ComposerModel]: + num_retries: int) -> ComposerModel: try: from peft import PeftModel except ImportError as e: @@ -43,7 +44,8 @@ def load_peft_model(model_cfg: DictConfig, tokenizer: PreTrainedTokenizerBase, } retries = 0 - while retries < num_retries: + composer_model_wrapper = None + while retries < num_retries and composer_model_wrapper is None: try: trust_remote_code = model_cfg.get('trust_remote_code', True) use_auth_token = model_cfg.get('use_auth_token', False) @@ -58,7 +60,6 @@ def load_peft_model(model_cfg: DictConfig, tokenizer: PreTrainedTokenizerBase, composer_model_wrapper = COMPOSER_MODEL_REGISTRY[model_cfg.name]( peft_model, tokenizer) - return composer_model_wrapper except Exception as e: retries += 1 if retries >= num_retries: @@ -68,19 +69,21 @@ def load_peft_model(model_cfg: DictConfig, tokenizer: PreTrainedTokenizerBase, f'Got exception {str(e)} while loading model {model_cfg.name}. {num_retries-retries} retries remaining' ) + assert composer_model_wrapper is not None + return composer_model_wrapper + def load_model(model_cfg: DictConfig, tokenizer: PreTrainedTokenizerBase, - fsdp_config: Optional[Dict], - num_retries: int) -> Optional[ComposerModel]: + fsdp_config: Optional[Dict], num_retries: int) -> ComposerModel: init_context = process_init_device(model_cfg, fsdp_config) retries = 0 + composer_model = None with init_context: - while retries < num_retries: + while retries < num_retries and composer_model is None: try: composer_model = COMPOSER_MODEL_REGISTRY[model_cfg.name]( model_cfg, tokenizer) - return composer_model except Exception as e: retries += 1 if retries >= num_retries: @@ -90,6 +93,9 @@ def load_model(model_cfg: DictConfig, tokenizer: PreTrainedTokenizerBase, f'Got exception {str(e)} while loading model {model_cfg.name}. 
{num_retries-retries} retries remaining' ) + assert composer_model is not None + return composer_model + def evaluate_model( model_cfg: DictConfig, @@ -100,6 +106,7 @@ def evaluate_model( max_seq_len: int, device_eval_batch_size: int, eval_gauntlet_config: Optional[Union[str, DictConfig]], + eval_loader_config: Optional[Union[DictConfig, ListConfig]], fsdp_config: Optional[Dict], num_retries: int, loggers_cfg: Dict[str, Any], @@ -118,9 +125,15 @@ def evaluate_model( tokenizer_kwargs = tokenizer_cfg.get('kwargs', {}) tokenizer = build_tokenizer(tokenizer_name, tokenizer_kwargs) - evaluators, logger_keys, eval_gauntlet_callback = build_icl_data_and_gauntlet( - icl_tasks, eval_gauntlet_config, tokenizer, device_eval_batch_size, - max_seq_len, icl_subset_num_batches) + evaluators, logger_keys, eval_gauntlet_callback = build_evaluators( + eval_loader_config, + icl_tasks, + eval_gauntlet_config, + tokenizer=tokenizer, + device_eval_batch_size=device_eval_batch_size, + icl_seq_len=max_seq_len, + icl_subset_num_batches=icl_subset_num_batches, + ) callbacks = [] if eval_gauntlet_callback is not None: @@ -143,6 +156,11 @@ def evaluate_model( composer_model = load_model(model_cfg.model, tokenizer, fsdp_config, num_retries) + # Now add the eval metrics + if eval_loader_config is not None: + train_metrics = composer_model.get_metrics(is_train=True) + evaluators = add_metrics_to_eval_loaders(evaluators, train_metrics) + if eval_gauntlet_df is None and eval_gauntlet_callback is not None: eval_gauntlet_df = pd.DataFrame( columns=['model_name'] + @@ -186,7 +204,7 @@ def evaluate_model( return (trainer, logger_keys, eval_gauntlet_callback, eval_gauntlet_df) -def main(cfg: DictConfig): +def main(cfg: DictConfig) -> Tuple[List[Trainer], pd.DataFrame]: om.resolve(cfg) model_configs: ListConfig = pop_config(cfg, 'models', must_exist=True) eval_gauntlet_config: Optional[Union[str, DictConfig]] = pop_config( @@ -228,6 +246,8 @@ def main(cfg: DictConfig): default_value='debug') # Optional Evaluation Parameters with default values + eval_loader_config: Optional[Union[DictConfig, ListConfig]] = pop_config( + cfg, 'eval_loader', must_exist=False, default_value=None) seed: int = pop_config(cfg, 'seed', must_exist=False, default_value=17) dist_timeout: Union[float, int] = pop_config(cfg, 'dist_timeout', @@ -274,6 +294,7 @@ def main(cfg: DictConfig): eval_gauntlet_df = None models_df = None composite_scores = None + trainers = [] for model_cfg in model_configs: (trainer, logger_keys, eval_gauntlet_callback, eval_gauntlet_df) = evaluate_model( @@ -285,6 +306,7 @@ def main(cfg: DictConfig): max_seq_len=max_seq_len, device_eval_batch_size=device_eval_batch_size, eval_gauntlet_config=eval_gauntlet_config, + eval_loader_config=eval_loader_config, fsdp_config=fsdp_config, num_retries=num_retries, loggers_cfg=loggers_cfg, @@ -292,6 +314,7 @@ def main(cfg: DictConfig): precision=precision, eval_gauntlet_df=eval_gauntlet_df, icl_subset_num_batches=icl_subset_num_batches) + trainers.append(trainer) if eval_gauntlet_callback is not None: composite_scores = eval_gauntlet_callback.eval_after_all( @@ -330,6 +353,8 @@ def main(cfg: DictConfig): assert models_df is not None print(models_df.to_markdown(index=False)) + return trainers, eval_gauntlet_df + def calculate_markdown_results(logger_keys: List[str], trainer: Trainer, benchmark_to_taxonomy: Dict[str, str], diff --git a/scripts/train/train.py b/scripts/train/train.py index 88f776375f..809f2fb09c 100644 --- a/scripts/train/train.py +++ b/scripts/train/train.py @@ -11,7 +11,6 @@ 
import torch from composer import Trainer -from composer.core import Evaluator from composer.core.callback import Callback from composer.loggers import MosaicMLLogger from composer.loggers.mosaicml_logger import (MOSAICML_ACCESS_TOKEN_ENV_VAR, @@ -26,10 +25,11 @@ from llmfoundry import (COMPOSER_MODEL_REGISTRY, ComposerHFCausalLM, MPTForCausalLM) from llmfoundry.data.dataloader import build_dataloader -from llmfoundry.utils.builders import (build_algorithm, build_callback, - build_icl_data_and_gauntlet, - build_logger, build_optimizer, - build_scheduler, build_tokenizer) +from llmfoundry.utils.builders import (add_metrics_to_eval_loaders, + build_algorithm, build_callback, + build_evaluators, build_logger, + build_optimizer, build_scheduler, + build_tokenizer) from llmfoundry.utils.config_utils import (log_config, pop_config, process_init_device, update_batch_size_info) @@ -526,31 +526,16 @@ def main(cfg: DictConfig) -> Trainer: ## Evaluation print('Building eval loader...') - evaluators = [] - eval_loaders = [] - if eval_loader_config is not None: - is_multi_eval = isinstance(eval_loader_config, ListConfig) - eval_configs = eval_loader_config if is_multi_eval else [ - eval_loader_config - ] - for eval_config in eval_configs: - eval_dataloader = build_dataloader(eval_config, tokenizer, - device_eval_batch_size) - eval_loader = Evaluator( - label=f'eval/{eval_config.label}' if is_multi_eval else 'eval', - dataloader=eval_dataloader, - metric_names=[], # we will add these after model is created - ) - eval_loaders.append(eval_loader) - - eval_gauntlet_callback = None - - if icl_tasks_config is not None: - icl_evaluators, _, eval_gauntlet_callback = build_icl_data_and_gauntlet( - icl_tasks_config, eval_gauntlet_config, tokenizer, - device_eval_batch_size, icl_seq_len if icl_seq_len else max_seq_len, - icl_subset_num_batches) - evaluators.extend(icl_evaluators) + eval_icl_seq_len: int = icl_seq_len if icl_seq_len else max_seq_len + evaluators, _, eval_gauntlet_callback = build_evaluators( + eval_loader_config, + icl_tasks_config, + eval_gauntlet_config, + tokenizer=tokenizer, + device_eval_batch_size=device_eval_batch_size, + icl_seq_len=eval_icl_seq_len, + icl_subset_num_batches=icl_subset_num_batches, + ) if eval_gauntlet_callback is not None: callbacks.append(eval_gauntlet_callback) @@ -581,11 +566,8 @@ def main(cfg: DictConfig) -> Trainer: # Now add the eval metrics if eval_loader_config is not None: - assert model.train_metrics is not None - eval_metric_names = list(model.train_metrics.keys()) - for eval_loader in eval_loaders: - eval_loader.metric_names = eval_metric_names - evaluators.insert(0, eval_loader) # Put the base eval_loaders first + train_metrics = model.get_metrics(is_train=True) + evaluators = add_metrics_to_eval_loaders(evaluators, train_metrics) # Build the Trainer print('Building trainer...') diff --git a/tests/data_utils.py b/tests/data_utils.py index 075933de7d..efb4f6d7cf 100644 --- a/tests/data_utils.py +++ b/tests/data_utils.py @@ -1,10 +1,26 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 -import json import os +import sys + +# Add repo root to path so we can import scripts and test it +repo_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) +sys.path.append(repo_dir) + +import json +import pathlib +import shutil +from argparse import Namespace from typing import Optional +from omegaconf import DictConfig +from omegaconf import OmegaConf as om + +from scripts.data_prep.convert_dataset_hf import main as main_hf 
# noqa: E402 +from scripts.data_prep.convert_dataset_json import \ + main as main_json # noqa: E402 + def make_tiny_ft_dataset( path: str, @@ -65,3 +81,83 @@ def make_tiny_ft_dataset( for sample in samples: _f.write(json.dumps(sample)) _f.write('\n') + + +def create_c4_dataset_xxsmall(path: pathlib.Path) -> str: + """Creates a small mocked version of the C4 dataset.""" + c4_dir = os.path.join(path, f'my-copy-c4') + downloaded_split = 'val_xxsmall' # very fast to convert + + # Hyperparameters from https://github.com/mosaicml/llm-foundry/blob/340a56658560ebceb2a3aa69d6e37813e415acd0/README.md#L188 + main_hf( + Namespace( + **{ + 'dataset': 'c4', + 'data_subset': 'en', + 'splits': [downloaded_split], + 'out_root': c4_dir, + 'compression': None, + 'concat_tokens': 2048, + 'tokenizer': 'EleutherAI/gpt-neox-20b', + 'tokenizer_kwargs': {}, + 'bos_text': '', + 'eos_text': '<|endoftext|>', + 'no_wrap': False, + 'num_workers': 8 + })) + + # copy the small downloaded_split to other c4 splits for mocking purposes + mocked_splits = ['train', 'val'] + for mocked_split in mocked_splits: + shutil.copytree(os.path.join(c4_dir, 'val_xxsmall'), + os.path.join(c4_dir, mocked_split)) + assert os.path.exists(c4_dir) + return c4_dir + + +def create_arxiv_dataset(path: pathlib.Path) -> str: + """Creates an arxiv dataset.""" + arxiv_dir = os.path.join(path, f'my-copy-arxiv') + downloaded_split = 'train' + + main_json( + Namespace( + **{ + 'path': 'data_prep/example_data/arxiv.jsonl', + 'out_root': arxiv_dir, + 'compression': None, + 'split': downloaded_split, + 'concat_tokens': None, + 'bos_text': None, + 'eos_text': None, + 'no_wrap': False, + 'num_workers': None + })) + + return arxiv_dir + + +def gpt_tiny_cfg(dataset_name: str, device: str): + """Create gpt tiny cfg.""" + conf_path: str = os.path.join(repo_dir, + 'scripts/train/yamls/pretrain/testing.yaml') + with open(conf_path) as f: + test_cfg = om.load(f) + assert isinstance(test_cfg, DictConfig) + + test_cfg.data_local = dataset_name + test_cfg.global_train_batch_size = 8 + test_cfg.device_eval_batch_size = 4 + test_cfg.device_train_microbatch_size = 4 + test_cfg.max_duration = '4ba' + test_cfg.eval_interval = '4ba' + test_cfg.run_name = 'gpt-mini-integration-test' + + if device == 'cpu': + test_cfg.model.init_device = 'cpu' + test_cfg.fsdp_config = None + test_cfg.model.attn_config.attn_impl = 'torch' + test_cfg.model.loss_fn = 'torch_crossentropy' + test_cfg.precision = 'fp32' + + return test_cfg diff --git a/tests/test_builders.py b/tests/test_builders.py index 7ac179720e..5c38ed8602 100644 --- a/tests/test_builders.py +++ b/tests/test_builders.py @@ -5,17 +5,22 @@ import unittest.mock as mock from copy import deepcopy from typing import Any, Dict, Union +from unittest.mock import MagicMock import pytest import torch import torch.nn as nn from composer.callbacks import Generate +from composer.core import Evaluator +from omegaconf import DictConfig, ListConfig from omegaconf import OmegaConf as om from transformers import PreTrainedTokenizerBase from llmfoundry.callbacks import HuggingFaceCheckpointer from llmfoundry.tokenizers.tiktoken import TiktokenTokenizerWrapper -from llmfoundry.utils.builders import (build_callback, build_optimizer, +from llmfoundry.utils.builders import (add_metrics_to_eval_loaders, + build_callback, build_eval_loaders, + build_evaluators, build_optimizer, build_tokenizer) @@ -195,3 +200,114 @@ def test_build_optimizer(name: str, optimizer_config: Dict[str, Any], for n, p in model.named_parameters(): if re.search(param_str_match, 
n): assert id(p) in param_ids + + +def test_build_evaluators_empty(): + evaluators, logger_keys, eval_gauntlet_callback = build_evaluators( + None, + None, + None, + tokenizer=None, # type: ignore + device_eval_batch_size=1, + icl_seq_len=2, + icl_subset_num_batches=3) + assert evaluators == [] + assert logger_keys == [] + assert eval_gauntlet_callback is None + + +def test_build_eval_loaders(monkeypatch: pytest.MonkeyPatch): + tokenizer = TiktokenTokenizerWrapper(model_name='gpt-4') + + eval_loader_cfg = DictConfig({ + 'name': 'text', + 'dataset': { + # mocked, not needed + }, + 'drop_last': False, + 'num_workers': 8, + }) + monkeypatch.setattr('llmfoundry.data.text_data.StreamingTextDataset', + lambda *args, **kwargs: MagicMock()) + eval_loaders = build_eval_loaders(eval_loader_cfg, tokenizer, 2) + + assert len(eval_loaders) == 1 + + assert eval_loaders[0].label == 'eval' + assert eval_loaders[0].dataloader is not None + assert eval_loaders[0].metric_names == [] + + multi_eval_loader_cfg = ListConfig([ + { + 'name': 'text', + 'label': 'test1', + 'dataset': { + # mocked, not needed + }, + 'drop_last': False, + 'num_workers': 8, + }, + { + 'name': 'text', + 'label': 'test2', + 'dataset': { + # mocked, not needed + }, + 'drop_last': False, + 'num_workers': 8, + } + ]) + monkeypatch.setattr('llmfoundry.data.text_data.StreamingTextDataset', + lambda *args, **kwargs: MagicMock()) + eval_loaders2 = build_eval_loaders(multi_eval_loader_cfg, tokenizer, 2) + + assert len(eval_loaders2) == 2 + + assert eval_loaders2[0].label == 'eval/test1' + assert eval_loaders2[0].dataloader is not None + assert eval_loaders2[0].metric_names == [] + + assert eval_loaders2[1].label == 'eval/test2' + assert eval_loaders2[1].dataloader is not None + assert eval_loaders2[1].metric_names == [] + + +def test_add_metrics_to_eval_loaders(): + evaluators = [ + Evaluator( + label='first', + metric_names=['a', 'b'], + dataloader=None, # type: ignore + device_eval_microbatch_size=1, + ), + Evaluator( + label='second', + metric_names=[], + dataloader=None, # type: ignore + device_eval_microbatch_size=1, + ), + Evaluator( + label='third', + metric_names=['c'], + dataloader=None, # type: ignore + device_eval_microbatch_size=1, + ) + ] + + new_evaluators = add_metrics_to_eval_loaders( + evaluators, + { + 'new1': 'foo', + 'new2': 'bar' + }, # type: ignore + ) + assert len(new_evaluators) == 3 + + assert new_evaluators[0].label == 'second' + assert new_evaluators[0].metric_names == ['new1', 'new2'] + + assert new_evaluators[1].label == 'first' + assert new_evaluators[1].metric_names == ['a', 'b'] + + assert new_evaluators[2].label == 'third' + assert new_evaluators[2].metric_names == ['c'] diff --git a/tests/test_dataloader.py b/tests/test_dataloader.py index c35d29f74d..2e9039644b 100644 --- a/tests/test_dataloader.py +++ b/tests/test_dataloader.py @@ -21,6 +21,7 @@ from llmfoundry import (build_finetuning_dataloader, build_text_denoising_dataloader) +from llmfoundry.data import build_dataloader from llmfoundry.data.text_data import (ConcatenatedSequenceCollatorWrapper, build_text_dataloader, get_tokens_per_batch_func) @@ -740,3 +741,13 @@ def test_token_counting_func_dataloader_setting( actual_token_count = dl.get_num_tokens_in_batch(batch_tokenized) assert actual_token_count == expected_token_count + + +def test_build_unknown_dataloader(): + cfg = DictConfig({ + 'name': 'unknown', + }) + tokenizer = MagicMock() + with pytest.raises(ValueError, + match='Expected dataloader name to be one of'): + _ = build_dataloader(cfg, 
tokenizer, 2) diff --git a/tests/test_eval.py b/tests/test_eval.py index 1217487b70..2fc96bb7ad 100644 --- a/tests/test_eval.py +++ b/tests/test_eval.py @@ -1,16 +1,21 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 +import copy import os +import pathlib import sys from typing import Any import omegaconf as om import pytest from composer import Trainer +from composer.loggers import InMemoryLogger from llmfoundry import COMPOSER_MODEL_REGISTRY from llmfoundry.utils import build_tokenizer +from tests.data_utils import (create_arxiv_dataset, create_c4_dataset_xxsmall, + gpt_tiny_cfg) # Add repo root to path so we can import scripts and test it repo_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) @@ -66,3 +71,87 @@ def test_icl_eval(capfd: Any, mock_saved_model_path: Any): assert expected_results in out expected_results = '| model_name | default_average | language_understanding_lite |\n|:-------------|------------------:|------------------------------:|\n| tiny_mpt | 0 | 0 |' assert expected_results in out + + +@pytest.mark.gpu +def test_loader_eval(capfd: Any, mock_saved_model_path: Any, + tmp_path: pathlib.Path): + + c4_dataset_name = create_c4_dataset_xxsmall(tmp_path) + + # Use a training config that already has eval loader configured + test_cfg = gpt_tiny_cfg(c4_dataset_name, 'cpu') + + # define icl eval task + test_cfg.icl_tasks = om.ListConfig([ + om.DictConfig({ + 'label': + 'lambada_openai', + 'dataset_uri': + 'eval/local_data/language_understanding/lambada_openai_small.jsonl', + 'num_fewshot': [0], + 'icl_task_type': + 'language_modeling' + }) + ]) + + # convert the model from a training to eval model + model = test_cfg.pop('model') + eval_model = { + 'model_name': model.get('name'), + 'model': model, + 'load_path': mock_saved_model_path + } + + tokenizer = test_cfg.pop('tokenizer') + eval_model['tokenizer'] = tokenizer + test_cfg.models = [eval_model] + + # Set up multiple eval dataloaders + first_eval_loader = test_cfg.eval_loader + first_eval_loader.label = 'c4' + # Create second eval dataloader using the arxiv dataset. 
+ second_eval_loader = copy.deepcopy(first_eval_loader) + arxiv_dataset_name = create_arxiv_dataset(tmp_path) + second_eval_loader.data_local = arxiv_dataset_name + second_eval_loader.label = 'arxiv' + test_cfg.eval_loader = om.OmegaConf.create( + [first_eval_loader, second_eval_loader]) + + test_cfg.max_duration = '1ba' + test_cfg.eval_interval = '1ba' + test_cfg.loggers = om.DictConfig({'inmemory': om.DictConfig({})}) + + trainers, eval_gauntlet_df = main(test_cfg) + + assert eval_gauntlet_df is None + assert len(trainers) == 1 # one per model + trainer = trainers[0] + + assert isinstance(trainer.logger.destinations, tuple) + + assert len(trainer.logger.destinations) > 0 + inmemorylogger = trainer.logger.destinations[ + 0] # pyright: ignore [reportGeneralTypeIssues] + assert isinstance(inmemorylogger, InMemoryLogger) + print(inmemorylogger.data.keys()) + + # Checks for first eval dataloader + assert 'metrics/eval/c4/LanguageCrossEntropy' in inmemorylogger.data.keys() + assert isinstance( + inmemorylogger.data['metrics/eval/c4/LanguageCrossEntropy'], list) + assert len( + inmemorylogger.data['metrics/eval/c4/LanguageCrossEntropy'][-1]) > 0 + assert isinstance( + inmemorylogger.data['metrics/eval/c4/LanguageCrossEntropy'][-1], tuple) + + # Checks for second eval dataloader + assert 'metrics/eval/arxiv/LanguageCrossEntropy' in inmemorylogger.data.keys( + ) + assert isinstance( + inmemorylogger.data['metrics/eval/arxiv/LanguageCrossEntropy'], list) + assert len( + inmemorylogger.data['metrics/eval/arxiv/LanguageCrossEntropy'][-1]) > 0 + assert isinstance( + inmemorylogger.data['metrics/eval/arxiv/LanguageCrossEntropy'][-1], + tuple) diff --git a/tests/test_eval_inputs.py b/tests/test_eval_inputs.py index 9c7a130a9b..83104b62b7 100644 --- a/tests/test_eval_inputs.py +++ b/tests/test_eval_inputs.py @@ -57,6 +57,7 @@ def test_optional_mispelled_params_raise_warning(self, 'loggers', 'eval_gauntlet', 'fsdp_config', + 'eval_loader', ] old_cfg = copy.deepcopy(cfg) for param in optional_params: diff --git a/tests/test_train_inputs.py b/tests/test_train_inputs.py index bf90f48ef0..2ed1c9c239 100644 --- a/tests/test_train_inputs.py +++ b/tests/test_train_inputs.py @@ -103,7 +103,7 @@ def test_optional_mispelled_params_raise_warning(self, 'save_folder', 'fsdp_config', 'lora_config', - 'eval_loader_config', + 'eval_loader', 'icl_tasks_config', ] old_cfg = copy.deepcopy(cfg) diff --git a/tests/test_training.py b/tests/test_training.py index 8390834d1d..3cd2963100 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -3,9 +3,6 @@ import copy import os import pathlib -import shutil -import sys -from argparse import Namespace from typing import Any, Optional import pytest @@ -14,95 +11,9 @@ from omegaconf import DictConfig, ListConfig from omegaconf import OmegaConf as om -# Add repo root to path so we can import scripts and test it -repo_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) -sys.path.append(repo_dir) - -from scripts.data_prep.convert_dataset_hf import main as main_hf # noqa: E402 -from scripts.data_prep.convert_dataset_json import \ - main as main_json # noqa: E402 from scripts.train.train import main # noqa: E402 - - -def create_c4_dataset_xsmall(path: pathlib.Path) -> str: - """Creates a small mocked version of the C4 dataset.""" - c4_dir = os.path.join(path, f'my-copy-c4') - downloaded_split = 'val_xxsmall' - main_hf( - Namespace( - **{ - 'dataset': 'c4', - 'data_subset': 'en', - 'splits': [downloaded_split], - 'out_root': c4_dir, - 'compression': None, - 
'concat_tokens': 2048, - 'tokenizer': 'EleutherAI/gpt-neox-20b', - 'tokenizer_kwargs': {}, - 'bos_text': '', - 'eos_text': '<|endoftext|>', - 'no_wrap': False, - 'num_workers': 8 - })) - - # copy the small downloaded_split to other c4 splits for mocking purposes - mocked_splits = ['train', 'val'] - for mocked_split in mocked_splits: - shutil.copytree(os.path.join(c4_dir, 'val_xxsmall'), - os.path.join(c4_dir, mocked_split)) - assert os.path.exists(c4_dir) - return c4_dir - - -def create_arxiv_dataset(path: pathlib.Path) -> str: - """Creates an arxiv dataset.""" - arxiv_dir = os.path.join(path, f'my-copy-arxiv') - downloaded_split = 'train' - - main_json( - Namespace( - **{ - 'path': 'data_prep/example_data/arxiv.jsonl', - 'out_root': arxiv_dir, - 'compression': None, - 'split': downloaded_split, - 'concat_tokens': None, - 'bos_text': None, - 'eos_text': None, - 'no_wrap': False, - 'num_workers': None - })) - - return arxiv_dir - - -def gpt_tiny_cfg(dataset_name: str, device: str): - """Create gpt tiny cfg.""" - conf_path: str = os.path.join(repo_dir, - 'scripts/train/yamls/pretrain/testing.yaml') - with open(conf_path) as f: - test_cfg = om.load(f) - assert isinstance(test_cfg, DictConfig) - - test_cfg.data_local = dataset_name - test_cfg.global_train_batch_size = 1 - test_cfg.device_eval_batch_size = 2 - test_cfg.device_train_microbatch_size = 1 - test_cfg.max_duration = '4ba' - test_cfg.eval_interval = '4ba' - test_cfg.run_name = 'gpt-mini-integration-test' - - test_cfg.model.n_layer = 2 - test_cfg.model.n_embd = 64 - - if device == 'cpu': - test_cfg.model.init_device = 'cpu' - test_cfg.fsdp_config = None - test_cfg.model.attn_config.attn_impl = 'torch' - test_cfg.model.loss_fn = 'torch_crossentropy' - test_cfg.precision = 'fp32' - - return test_cfg +from tests.data_utils import (create_arxiv_dataset, create_c4_dataset_xxsmall, + gpt_tiny_cfg) @pytest.fixture(autouse=False) @@ -122,7 +33,7 @@ def set_correct_cwd(): def test_train_gauntlet(averages: Optional[dict], set_correct_cwd: Any, tmp_path: pathlib.Path): """Test training run with a small dataset.""" - dataset_name = create_c4_dataset_xsmall(tmp_path) + dataset_name = create_c4_dataset_xxsmall(tmp_path) test_cfg = gpt_tiny_cfg(dataset_name, 'cpu') test_cfg.icl_tasks = ListConfig([ DictConfig({ @@ -201,7 +112,7 @@ def test_train_gauntlet(averages: Optional[dict], set_correct_cwd: Any, def test_train_multi_eval(set_correct_cwd: Any, tmp_path: pathlib.Path): """Test training run with multiple eval datasets.""" - c4_dataset_name = create_c4_dataset_xsmall(tmp_path) + c4_dataset_name = create_c4_dataset_xxsmall(tmp_path) test_cfg = gpt_tiny_cfg(c4_dataset_name, 'cpu') # Set up multiple eval dataloaders first_eval_loader = test_cfg.eval_loader From 22ae919c1d6b2b542278399586a1835e0e632bba Mon Sep 17 00:00:00 2001 From: Sam Havens Date: Thu, 30 Nov 2023 17:47:43 -0800 Subject: [PATCH 4/6] Support inputs_embeds (#687) * support inputs_embeds * update tests to test inputs_embeds * make iids optional inputs to fwd * remove check for both iids and inputs_embeds in MPTForCausalLM. It is checked in the base model, and it is actually a common practice to pass both during autoregressive generation. 
Embeds are used first, then once the kvcache is nonempty, iids are used instead * reorder kwargs * add more tests * fix device merge artifact in test_model.oy * fix generate test * yapf --- llmfoundry/models/mpt/modeling_mpt.py | 51 +++++++++-------- tests/test_model.py | 79 +++++++++++++++++++++++++-- 2 files changed, 101 insertions(+), 29 deletions(-) diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 274c1b76e5..d6b23c04d0 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -368,7 +368,7 @@ def _apply_sequence_id(self, attn_bias: torch.Tensor, def forward( self, - input_ids: torch.LongTensor, + input_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None, attention_mask: Optional[torch.ByteTensor] = None, prefix_mask: Optional[torch.ByteTensor] = None, @@ -412,11 +412,6 @@ def forward( 'prefix_mask is a required argument when MPT is configured with prefix_lm=True.' ) - # Raise a not implemented error if input_embeds is not None (this is an arg in huggingface transformers and we need to support it for PEFT) - if inputs_embeds is not None: - raise NotImplementedError( - 'inputs_embeds is not implemented for MPT.') - if self.training: if self.attn_uses_sequence_id and sequence_id is None: raise ValueError( @@ -430,14 +425,25 @@ def forward( 'This input will be ignored. If you want the model to use `sequence_id`, set attn_uses_sequence_id to True.' ) - S = input_ids.size(1) + if input_ids is not None and inputs_embeds is not None: + raise ValueError( + 'You cannot specify both input_ids and inputs_embeds.') + elif input_ids is not None: + S = input_ids.size(1) + x = self.wte(input_ids) + input_device = input_ids.device + elif inputs_embeds is not None: + S = inputs_embeds.size(1) + x = inputs_embeds + input_device = inputs_embeds.device + else: + raise ValueError('You must specify input_ids or inputs_embeds') assert ( S <= self.config.max_seq_len ), f'Cannot forward input with seq_len={S}, this model only supports seq_len<={self.config.max_seq_len}' rotary_emb_w_meta_info = None - x = self.wte(input_ids) if self.learned_pos_emb or self.rope: past_position = 0 if past_key_values is not None: @@ -467,7 +473,7 @@ def forward( past_position, S + past_position, dtype=torch.long, - device=input_ids.device, + device=input_device, ).unsqueeze(0) if attention_mask is not None: # adjust the position indices to account for padding tokens @@ -652,7 +658,7 @@ def get_decoder(self) -> MPTModel: def forward( self, - input_ids: torch.LongTensor, + input_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None, attention_mask: Optional[torch.ByteTensor] = None, prefix_mask: Optional[torch.ByteTensor] = None, @@ -669,11 +675,6 @@ def forward( use_cache = (use_cache if use_cache is not None else self.config.use_cache) - # if input_embeds is not none, raise a not implemented error - if inputs_embeds is not None: - raise NotImplementedError( - 'inputs_embeds has to be None (for hf/peft support).') - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.transformer( input_ids=input_ids, past_key_values=past_key_values, @@ -684,6 +685,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache, + inputs_embeds=inputs_embeds, ) if self.lm_head is not None: @@ -773,10 +775,6 @@ def prepare_inputs_for_generation( inputs_embeds: 
Optional[torch.Tensor] = None, **kwargs: Any, ) -> Dict[str, Any]: - if inputs_embeds is not None: - raise NotImplementedError( - 'inputs_embeds is not implemented for MPT yet') - attention_mask = kwargs['attention_mask'].bool() if attention_mask[:, -1].sum() != attention_mask.shape[0]: raise NotImplementedError( @@ -787,6 +785,7 @@ def prepare_inputs_for_generation( else: sequence_id = None + # only last token for inputs_ids if past is defined in kwargs if past_key_values is not None: input_ids = input_ids[:, -1].unsqueeze(-1) @@ -800,14 +799,20 @@ def prepare_inputs_for_generation( else: prefix_mask = None - return { - 'input_ids': input_ids, + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {'inputs_embeds': inputs_embeds} + else: + model_inputs = {'input_ids': input_ids} + + model_inputs.update({ 'attention_mask': attention_mask, 'prefix_mask': prefix_mask, 'sequence_id': sequence_id, 'past_key_values': past_key_values, 'use_cache': kwargs.get('use_cache', True), - } + }) + return model_inputs @staticmethod def _reorder_cache( @@ -898,7 +903,7 @@ def forward(self, batch: MutableMapping) -> CausalLMOutputWithPast: add_bidirectional_mask_if_missing(batch) # Note: prefix_mask is only used if model.prefix_lm is True return self.model( - input_ids=batch['input_ids'], + input_ids=batch.get('input_ids', None), attention_mask=batch.get('attention_mask', None), prefix_mask=batch.get('bidirectional_mask', None), sequence_id=batch.get('sequence_id', None), diff --git a/tests/test_model.py b/tests/test_model.py index 4d5b0a4dbc..acb2074ae9 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -5,7 +5,7 @@ import os import pathlib import warnings -from typing import Any, Dict, Union, cast +from typing import Any, Dict, List, Optional, Union, cast from unittest import mock import pytest @@ -94,13 +94,26 @@ def get_objs(conf_path: str = 'scripts/train/yamls/pretrain/testing.yaml'): return test_cfg, model, optimizer -def gen_random_batch(batch_size: int, test_cfg: Union[DictConfig, ListConfig]): +def gen_random_batch(batch_size: int, + test_cfg: Union[DictConfig, ListConfig], + inputs: Optional[List[str]] = None): + # inputs can be [], ['input_ids'], ['input_ids', 'inputs_embeds'], and ['inputs_embeds'] + # default to only input ids + if inputs == None: + inputs = ['input_ids'] # generate input batch of random data, suitable for a Causal or Prefix LM batch = {} - batch['input_ids'] = torch.randint( - low=0, - high=test_cfg.model.vocab_size, - size=(batch_size, test_cfg.max_seq_len)).to(test_cfg.device) + for inp in inputs: + if inp == 'input_ids': + batch['input_ids'] = torch.randint( + low=0, + high=test_cfg.model.vocab_size, + size=(batch_size, test_cfg.max_seq_len)).to(test_cfg.device) + if inp == 'inputs_embeds': + batch['inputs_embeds'] = torch.randn( + batch_size, test_cfg.max_seq_len, + test_cfg.model.d_model).to(test_cfg.device) + batch['labels'] = torch.randint(low=0, high=test_cfg.model.vocab_size, size=(batch_size, test_cfg.max_seq_len)).to( @@ -150,6 +163,34 @@ def test_full_forward_and_backward(batch_size: int = 2): assert not torch.equal(original_params, updated_params) +def test_full_forward_and_backward_with_inputs_embeds(batch_size: int = 2): + test_cfg, model, optimizer = get_objs( + conf_path='scripts/train/yamls/pretrain/testing.yaml') + + batch = gen_random_batch(batch_size, test_cfg, inputs=['inputs_embeds']) + + model.train() + original_params = 
next(model.parameters()).clone().data + outputs = model(batch) + loss = model.loss(outputs, batch) + loss.backward() + optimizer.step() + updated_params = next(model.parameters()).clone().data + assert not torch.equal(original_params, updated_params) + + +@pytest.mark.parametrize('inputs', [[], ['input_ids', 'inputs_embeds']]) +def test_invalid_inputs_embeds_input_ids_combinations(inputs: List[str]): + test_cfg, model, _ = get_objs( + conf_path='scripts/train/yamls/pretrain/testing.yaml') + + batch = gen_random_batch(2, test_cfg, inputs=inputs) + + model.train() + with pytest.raises(ValueError): + _ = model(batch) + + def test_attention_mechanism(batch_size: int = 2): test_cfg, model, _ = get_objs( conf_path='scripts/train/yamls/pretrain/testing.yaml') @@ -825,6 +866,9 @@ def test_generate(attention_impl: str, precision: str, pos_emb_config: dict, no_padding_attention_mask = composer_device.tensor_to_device( no_padding_attention_mask) + # inputs_embeds + inputs_embeds = composer_device.tensor_to_device(torch.randn(2, 3, 128)) + # a single batch with different amounts of left padding in the input batched_input_ids = torch.tensor([[50256, 50256, 50256, 11274, 16390, 11], [50256, 50256, 16, 11274, 16390, 11]]) @@ -860,6 +904,29 @@ def test_generate(attention_impl: str, precision: str, pos_emb_config: dict, assert generation_with_no_padding[:, 3:].equal( generation_with_left_padding[:, 6:]) + # check that both/neither ids and embeds do not error + # note that we need to set the BOS token ID for generating from neither + _ = mpt.generate(input_ids=no_padding_input_ids, + inputs_embeds=inputs_embeds, + attention_mask=no_padding_attention_mask, + max_new_tokens=5, + use_cache=False) + _ = mpt.generate(input_ids=no_padding_input_ids, + inputs_embeds=inputs_embeds, + attention_mask=no_padding_attention_mask, + max_new_tokens=5, + use_cache=True) + _ = mpt.generate(input_ids=None, + inputs_embeds=None, + max_new_tokens=5, + use_cache=False, + bos_token_id=50256) + _ = mpt.generate(input_ids=None, + inputs_embeds=None, + max_new_tokens=5, + use_cache=True, + bos_token_id=50256) + @pytest.mark.gpu @pytest.mark.parametrize('world_size', [1, 2]) From 9cf99b7457a6ed0e199a56785dad697bd4a09a58 Mon Sep 17 00:00:00 2001 From: Anna Date: Thu, 30 Nov 2023 19:38:01 -0800 Subject: [PATCH 5/6] Better error message when test does not complete (#769) --- .github/mcp/mcp_pytest.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/mcp/mcp_pytest.py b/.github/mcp/mcp_pytest.py index 5f0aaa147b..b6d74880c8 100644 --- a/.github/mcp/mcp_pytest.py +++ b/.github/mcp/mcp_pytest.py @@ -130,7 +130,7 @@ print(line, end='') print('[GHA] Run completed. 
Waiting for run to finish...') - run = wait_for_run_status(run, status='completed') + run = wait_for_run_status(run, status=RunStatus.COMPLETED) - # Fail if command exited with non-zero exit code or timed out - assert run.status == RunStatus.COMPLETED + # Fail if command exited with non-zero exit code or timed out (didn't reach COMPLETED) + assert run.status == RunStatus.COMPLETED, f'Run did not complete: {run.status} ({run.reason})' From 32dc3bd85134b0362b234abb03d7ebae04bb5ac6 Mon Sep 17 00:00:00 2001 From: Daniel King <43149077+dakinggg@users.noreply.github.com> Date: Fri, 1 Dec 2023 09:39:54 -0800 Subject: [PATCH 6/6] Add codeowners (#770) * add codeowners * precommit --- .github/CODEOWNERS | 8 ++++++++ 1 file changed, 8 insertions(+) create mode 100644 .github/CODEOWNERS diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS new file mode 100644 index 0000000000..bbdd4259cd --- /dev/null +++ b/.github/CODEOWNERS @@ -0,0 +1,8 @@ +# Require admin approval to modify all files in the root of the repository +# This includes setup.py, the README, and the CODEOWNERS file itself! +/* @mosaicml/composer-team-admins + +# Require admin approval to change the CI build configuration +# All CI Changes should be reviewed for security +/.ci/ @mosaicml/composer-team-admins +/.github/ @mosaicml/composer-team-admins