diff --git a/.gitignore b/.gitignore index d041a25c22..1dd80a8b6c 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,7 @@ my-copy-c4*/ my-copy-arxiv*/ *.jsonl* +!tests/eval/local_data/*.jsonl # WandB wandb/ diff --git a/llmfoundry/__init__.py b/llmfoundry/__init__.py index 922f738e9a..54a55d6e97 100644 --- a/llmfoundry/__init__.py +++ b/llmfoundry/__init__.py @@ -28,7 +28,7 @@ MultiheadAttention, attn_bias_shape, build_alibi_bias, build_attn_bias, flash_attn_fn, scaled_multihead_dot_product_attention) from llmfoundry.models.layers.blocks import MPTBlock -from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY, MPTMLP, build_ffn +from llmfoundry.models.layers.ffn import MPTMLP from llmfoundry.models.mpt import (ComposerMPTCausalLM, MPTConfig, MPTForCausalLM, MPTModel, MPTPreTrainedModel) from llmfoundry.tokenizers import TiktokenTokenizerWrapper @@ -37,9 +37,7 @@ 'build_finetuning_dataloader', 'Seq2SeqFinetuningCollator', 'MPTBlock', - 'FFN_CLASS_REGISTRY', 'MPTMLP', - 'build_ffn', 'MPTConfig', 'MPTPreTrainedModel', 'MPTModel', diff --git a/llmfoundry/data/finetuning/tasks.py b/llmfoundry/data/finetuning/tasks.py index 42b15e4d6e..4906cea151 100644 --- a/llmfoundry/data/finetuning/tasks.py +++ b/llmfoundry/data/finetuning/tasks.py @@ -77,7 +77,7 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]: _ALLOWED_MESSAGES_KEYS = {'messages'} _ALLOWED_ROLE_KEYS = {'role'} _ALLOWED_CONTENT_KEYS = {'content'} -_ALLOWED_ROLES = {'user', 'assistant', 'system'} +_ALLOWED_ROLES = {'user', 'assistant', 'system', 'tool'} _ALLOWED_LAST_MESSAGE_ROLES = {'assistant'} DOWNLOADED_FT_DATASETS_DIRPATH = os.path.abspath( os.path.join(os.path.realpath(__file__), os.pardir, os.pardir, os.pardir, @@ -217,7 +217,7 @@ def slice_out_last_turn( if conversation_through_previous_turn != prompt_with_history[:len( conversation_through_previous_turn)]: raise ValueError( - f'The prompt_with_histry must start with the conversation through the previous turn. {conversation_through_previous_turn=}, {prompt_with_history=}' + f'The prompt_with_history must start with the conversation through the previous turn. {conversation_through_previous_turn=}, {prompt_with_history=}' ) if prompt_with_history != full_conversation[:len(prompt_with_history)]: raise ValueError( @@ -624,9 +624,8 @@ def print_registered_tasks(self) -> None: log.info('\n'.join(tasks)) def get_preprocessing_fn_from_dict( - self, - mapping: Dict[str, - str]) -> Callable[[Dict[str, Any]], Dict[str, str]]: + self, mapping: Dict[str, + str]) -> Callable[[Dict[str, Any]], Example]: """Get a preprocessing function from a dictionary. The dictionary maps column names in the dataset to "prompt" and "response". @@ -662,7 +661,7 @@ def get_preprocessing_fn_from_str( self, preprocessor: Optional[str], dataset_name: Optional[str] = None - ) -> Optional[Callable[[Dict[str, Any]], Dict[str, str]]]: + ) -> Optional[Callable[[Dict[str, Any]], Example]]: """Get a preprocessing function from a string. String can be either a registered function or an import path. @@ -710,7 +709,7 @@ def get_preprocessing_fn_from_str( def build_from_hf( self, dataset_name: str, split: str, safe_load: bool, max_seq_len: int, - preprocessing_fn: Optional[Callable[[dict[str, Any]], dict[str, str]]], + preprocessing_fn: Optional[Callable[[dict[str, Any]], Example]], tokenizer: PreTrainedTokenizerBase, target_prompts: str, target_responses: str, decoder_only_format: bool, hf_kwargs: Dict[str, Any] @@ -793,7 +792,8 @@ def build_from_hf( def dataset_mapper(example: Dict): if preprocessing_fn is not None: - example = preprocessing_fn(example) + return tokenize_formatted_example(preprocessing_fn(example), + tokenizer) return tokenize_formatted_example(example, tokenizer) detected_cpu_count = os.cpu_count() or 1 @@ -857,7 +857,7 @@ def build_from_streaming(self, *args: Any, @dataset_constructor.register('tatsu-lab/alpaca') -def alpaca_preprocessing_function(inp: Dict) -> Dict[str, str]: +def alpaca_preprocessing_function(inp: Dict) -> PromptResponseDict: """Split out prompt/response from text.""" try: prompt, response = inp['text'].split('### Response:') @@ -869,7 +869,7 @@ def alpaca_preprocessing_function(inp: Dict) -> Dict[str, str]: @dataset_constructor.register('HuggingFaceH4/databricks_dolly_15k') -def dolly_preprocessing_function(inp: Dict) -> Dict[str, str]: +def dolly_preprocessing_function(inp: Dict) -> PromptResponseDict: """Format the text string.""" PROMPT_FORMAT = 'Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:\n' try: @@ -885,7 +885,7 @@ def dolly_preprocessing_function(inp: Dict) -> Dict[str, str]: @dataset_constructor.register('bigscience/P3') -def p3_preprocessing_function(inp: Dict) -> Dict[str, str]: +def p3_preprocessing_function(inp: Dict) -> PromptResponseDict: """Format the already-split example.""" return { 'prompt': inp['inputs'] + ':', @@ -895,7 +895,7 @@ def p3_preprocessing_function(inp: Dict) -> Dict[str, str]: # Muennighoff's P3 and flan datasets share a similar convention @dataset_constructor.register('Muennighoff/P3', 'Muennighoff/flan') -def muennighoff_tokenize_function(inp: Dict) -> Dict[str, str]: +def muennighoff_tokenize_function(inp: Dict) -> PromptResponseDict: """Format the already-split example.""" try: prompt: str = inp['inputs'] @@ -908,3 +908,22 @@ def muennighoff_tokenize_function(inp: Dict) -> Dict[str, str]: except Exception as e: raise UnableToProcessPromptResponseError(inp) from e return {'prompt': prompt, 'response': response} + + +@dataset_constructor.register('teknium/OpenHermes-2.5') +def shareGPT_format_preprocessor(inp: Dict) -> ChatFormattedDict: + """Convert from ShareGPT format to our chat format.""" + role_map = { + 'human': 'user', + 'gpt': 'assistant', + } + try: + conversation = inp['conversations'] + messages: List[Dict[str, str]] = [] + for message in conversation: + role: str = role_map.get(message['from'], message['from']) + content: str = message['value'] + messages.append({'role': role, 'content': content}) + except Exception as e: + raise UnableToProcessPromptResponseError(inp) from e + return {'messages': messages} diff --git a/llmfoundry/eval/__init__.py b/llmfoundry/eval/__init__.py new file mode 100644 index 0000000000..80950cb7b4 --- /dev/null +++ b/llmfoundry/eval/__init__.py @@ -0,0 +1,2 @@ +# Copyright 2024 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 diff --git a/llmfoundry/eval/datasets/__init__.py b/llmfoundry/eval/datasets/__init__.py new file mode 100644 index 0000000000..e6a8b5222d --- /dev/null +++ b/llmfoundry/eval/datasets/__init__.py @@ -0,0 +1,31 @@ +# Copyright 2024 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +"""Natively supported in-context learning evaluation datasets.""" + +from llmfoundry.eval.datasets.in_context_learning_evaluation import ( + InContextLearningCodeEvalDataset, InContextLearningDataset, + InContextLearningGenerationTaskWithAnswersDataset, + InContextLearningLMTaskDataset, InContextLearningMultipleChoiceTaskDataset, + InContextLearningSchemaTaskDataset, get_icl_task_dataloader) +from llmfoundry.eval.datasets.utils import (get_continuation_span, + get_fewshot_sample_idxs, + make_padded_input, strip_data, + tokenizer_needs_prefix_space, + trim_context) + +__all__ = [ + 'InContextLearningDataset', + 'InContextLearningGenerationTaskWithAnswersDataset', + 'InContextLearningLMTaskDataset', + 'InContextLearningCodeEvalDataset', + 'InContextLearningMultipleChoiceTaskDataset', + 'InContextLearningSchemaTaskDataset', + 'get_icl_task_dataloader', + 'strip_data', + 'tokenizer_needs_prefix_space', + 'trim_context', + 'get_continuation_span', + 'get_fewshot_sample_idxs', + 'make_padded_input', +] diff --git a/llmfoundry/eval/datasets/in_context_learning_evaluation.py b/llmfoundry/eval/datasets/in_context_learning_evaluation.py new file mode 100644 index 0000000000..8f317f60b8 --- /dev/null +++ b/llmfoundry/eval/datasets/in_context_learning_evaluation.py @@ -0,0 +1,1791 @@ +# Copyright 2024 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import copy +import json +import logging +import os +import random +import warnings +from typing import Any, Dict, Iterable, List, Optional, Sequence, Union + +import torch +import transformers +from composer.core import DataSpec +from composer.core.data_spec import _default_split_batch, _split_list +from composer.datasets.utils import stop_sequences_criteria +from composer.utils import MissingConditionalImportError, dist, get_file +from datasets import Dataset as HFDataset +from datasets import IterableDataset, load_dataset +from torch.utils.data import DataLoader, Dataset + +from llmfoundry.eval.datasets.utils import (convert_tokens_to_tensors, + get_continuation_span, + get_fewshot_sample_idxs, + make_padded_input, strip_data, + tokenizer_needs_prefix_space, + trim_context) +from llmfoundry.utils.warnings import VersionedDeprecationWarning + +log = logging.getLogger(__name__) + +# Allow models to have slightly more tokens than were used in the most verbose CoT in the dataset +_MAX_ANSWER_BUFFER_LENGTH = 10 + +__all__ = [ + 'InContextLearningLMTaskDataset', + 'InContextLearningMultipleChoiceTaskDataset', + 'InContextLearningSchemaTaskDataset', + 'InContextLearningCodeEvalDataset', + 'InContextLearningGenerationTaskWithAnswersDataset', + 'get_icl_task_dataloader', +] + + +class InContextLearningDataset(Dataset): + r"""A base dataset that constructs batches for in-context learning task. + + evaluations. The dataset format is expected to be a local jsonl file, a + cloud link to a jsonl file, or a Hugging Face dataset link. 'context' refers + to the input a model will recieve before generating an output. For example, + the question in question answering tasks, the preceding text in a language + modeling task, or the document and question regarding the document in a + document understanding task. 'example' refers to a loaded dictionary, + generally containing a context, an answer, and any other information needed + to run the task. 'answer' refers to the desired output of the model. + + When creating a new ICL Dataset, it is likely that you will need to reimplement the following methods: + + - construct_context(): Takes a single example dictionary and formulates the context as a string for that eval question. + - get_answer_from_example(): Takes a single example dictionary and formulates the correct, ground truth answer as a string. + - tokenize_example(): Tokenizes the example and adds any extra content from the original dictionary that needs to be passed downstream. + - read_dataset(): Loads the dataset and does basic parsing. If additional parsing must be done, this is a good place to do so (See InContextLearningGenerationTaskWithAnswersDataset.read_dataset()) + + Additionally, base_batch and batch_mapping must be defined. + + - base_batch (Dict): The base dictionary that the dataset will use to construct a batch. This should contain static values, like generation_kwargs or mode, + and empty lists for values that will need to be accumulated from each example. + NOTE: Sometimes you will need to set base_batch directly after the init call, e.g. in order to use class variables + like self.pad_tok_id or self.max_answer_length. If you manually set generation_kwargs this way, you'll need to call self.update_generation_kwargs() + after setting self.base_batch. + - batch_mapping (Dict): A mapping with keys that are keys in the batch and values that are columns in the loaded dataset. + collate_fn will use this mapping to create batches from self.dataset. + + Args: + dataset_uri (str): A local path, a remote path beginning with ``s3://`` or another backend, or a HuggingFace dataset uri prepended with ``hf://``. + Alternate backends must be supported by :meth:`composer.utils.maybe_create_object_store_from_uri`. + A local dataset must consist of rows of JSON data points with task dependent fields. + The default keys expected are "context" and "answer". + tokenizer (transformers.PreTrainedTokenizerBase): The tokenizer used to map between strings and token ids. + max_seq_len (int): The maximum sequence length supported by the model. + pad_tok_id (int): The special token used for padding batches. + num_fewshot (int): The number of complete fewshot examples to prepend before each test example. These are not identical across examples. + fewshot_random_seed (int): Random seed to use for fewshot sampling. + prompt_string (str): Prompt string to put once before all fewshot examples/test examples (e.g. 'Translate english to french.'). + example_delimiter (str): Separator inserted before (context, answer) pairs (e.g. '\\n') for fewshot sampling and prompting. + continuation_delimiter: (str): Separator inserted between context and answer in each example (e.g. '\\nA: '). + destination_path (str): Temporary path to store downloaded datasets. + prelimiter (str): Text to be prepended before each context, including few shot examples (e.g. "Question: "). + context_key (str): The key in the loaded dataset that contains the context. + answer_key (str): The key in the loaded dataset that contains the answer. + strip_dataset (bool): Boolean for whether to strip whitespace from data. Trailing whitespace can cause degenerative outputs, + so unless whitespace should be preserved (for example in code), this should be set to True. + padding_side (str): Side of the content and answer on which to apply padding. Can be either 'right' or 'left'. + tokenize_labels (bool): Whether or not the labels should be tokenized. Generally determined by which metric a dataset uses. + padding_size (int): The final size of the tensor after padding. Defaults to max_sequence_length. + base_batch (Dict): The base dictionary upon which a batch is created. See above for more details. + base_mapping (Dict): A mapping of batch keys to dataset columns, used to create batches. See above for more details. + hf_loading_vars (Dict): A dictionary containing keyword arguments to be passed into `load_dataset` if dataset is being pulled from HF. + hf_parsing_map (Dict): A dictionary containing a mapping from HF columns to ICL dataset keys. The dictionary should be formatted {icl_key:[hf_key1, hf_key1]}. + Column contents will be concatenated with ' ' seperating them. If not included, will load the columns already present in the HF dataset. + generation_kwargs (Dict): A dictionary containing keyword arguments to be passed along to the model's generate function. + static_keys (List): A list of the key values which will be broadcast across a batch (e.g. it is the same for each batch element). + list_keys (List): A list of the batch keys whose values are lists which will be split using list methods during calls to split_batch. + tensor_keys (List): A list of the batch keys whose values are tensors which will be split using tensor methods during calls to split_batch. + """ + + def __init__( + self, + dataset_uri: str, + tokenizer: transformers.PreTrainedTokenizerBase, + max_seq_len: int, + pad_tok_id: int, + num_fewshot: int, + fewshot_random_seed: int, + prompt_string: str, + example_delimiter: str, + continuation_delimiter: str, + destination_path: str, + prelimiter: str = '', + context_key: str = 'context', + answer_key: str = 'answer', + strip_dataset: bool = True, + padding_side: str = 'right', + tokenize_labels: bool = True, + padding_size: Optional[int] = None, + base_batch: Optional[Dict] = None, + batch_mapping: Optional[Dict] = None, + hf_loading_vars: Optional[Dict] = None, + hf_parsing_map: Optional[Dict] = None, + generation_kwargs: Optional[Dict] = None, + static_keys: Optional[List] = None, + list_keys: Optional[List] = None, + tensor_keys: Optional[List] = None, + ): + self.tokenizer = tokenizer + self.prefix_space = tokenizer_needs_prefix_space(self.tokenizer) + + self.max_seq_len = max_seq_len + self.pad_tok_id = pad_tok_id + self.num_fewshot = num_fewshot + self.padding_side = padding_side + self.padding_size = padding_size if padding_size else self.max_seq_len + self.prelimiter = prelimiter + self.example_delimiter = example_delimiter + self.continuation_delimiter = continuation_delimiter + self.context_key = context_key + self.answer_key = answer_key + self.tokenize_labels = tokenize_labels + self.batch_mapping = batch_mapping or {} + self.base_batch = base_batch or {} + if generation_kwargs: + self.update_generation_kwargs(generation_kwargs) + + self.static_keys = static_keys + self.list_keys = list_keys + self.tensor_keys = tensor_keys + + hf_loading_vars = hf_loading_vars or {} + self.dataset: HFDataset = self.read_dataset(dataset_uri, + destination_path, + hf_loading_vars, + hf_parsing_map) + self.strip_data = strip_dataset + if self.strip_data: + self.dataset = self.dataset.map(strip_data) + + fewshot_rng = random.Random(fewshot_random_seed) + self.dataset: HFDataset = self.dataset.map( + self._prep_example, + with_indices=True, + fn_kwargs={ + 'num_fewshot': num_fewshot, + 'prompt_string': prompt_string, + 'fewshot_rng': fewshot_rng, + }, + ) + + def __getitem__(self, index: int) -> Dict: + return self.dataset[index] + + def __len__(self) -> int: + return len(self.dataset) + + def get_num_samples_in_batch(self, batch: Dict) -> int: + return batch['input_ids'].shape[0] + + def update_generation_kwargs(self, generation_kwargs: Dict) -> None: + r"""Updates self.base_batch with the passed in generation_kwargs. + + This must be run after self.base_batch is set (for example, if + self.base_batch is set after __init__() is run, likely because + base_batch needs a class variable like self.pad_tok_id or + self.max_answer_length). + + Args: + generation_kwargs (Dict): Keyword arguments that be written into base_batch['generation_kwargs'] + """ + if generation_kwargs: + if 'generation_kwargs' not in self.base_batch: + self.base_batch['generation_kwargs'] = {} + self.base_batch['generation_kwargs'].update(generation_kwargs) + + def read_dataset( + self, + dataset_uri: str, + destination_path: str, + hf_loading_vars: Optional[Dict[str, Any]] = None, + hf_parsing_map: Optional[Dict[str, Any]] = None) -> 'HFDataset': + """Reads a dataset and handles parsing it from HuggingFace. + + Args: + dataset_uri (str): A local path, a remote path beginning with ``s3://`` or another backend, or a HuggingFace dataset uri. + Alternate backends must be supported by :meth:`composer.utils.maybe_create_object_store_from_uri`. + destination_path (str): A local path where the data will be stored + hf_loading_vars (Dict): If parsing from HuggingFace, keyword args that will be passed into load_dataset + hf_parsing_map (Dict): Dictionary in the form of {icl_key: [hf_col1, hf_col2]} that will map one or more hf columns, in order, to ICL dataset columns + + Returns: + dataset: A loaded HF dataset + """ + from datasets import \ + Dataset as HFDataset # pyright: ignore[reportGeneralTypeIssues] + from datasets import \ + load_dataset # pyright: ignore[reportGeneralTypeIssues] + if 'hf://' in dataset_uri: + dataset_uri = dataset_uri.replace('hf://', '') + if hf_loading_vars is None: + hf_loading_vars = {} + dataset = load_dataset(dataset_uri, **hf_loading_vars) + if hf_parsing_map: + dataset_parsing_func = lambda example: { + k: ' '.join([str(example[col]) for col in v]) + for k, v in hf_parsing_map. + items( # pyright: ignore[reportOptionalMemberAccess] + ) + } + assert isinstance(dataset, HFDataset) + dataset = dataset.map(dataset_parsing_func, + remove_columns=dataset.column_names) + else: + with dist.local_rank_zero_download_and_wait(destination_path): + if dist.get_local_rank() == 0: + get_file(dataset_uri, destination_path, overwrite=True) + dataset = load_dataset('json', + data_files=destination_path, + split='train', + streaming=False) + assert isinstance(dataset, HFDataset) + return dataset + + def _generate_few_shot_prompt( + self, + num_fewshot: int, + example_idx: int, + preamble: str, + fewshot_rng: random.Random, + ) -> str: + """Formats the fewshot prompt for test example `example_idx`. + + Randomly selects `num_fewshot` samples from the dataset (excluding the example at `example_idx`) and constructs + contextes with answers appended. + + Returns the formatted prompt_string + concatenated list of formatted few shot examples as a string. + + Args: + num_fewshot (int): Number of examples to prepend + example_idx (int): Current example idx + preamble (str): Text to occur at the beginning of the task. Generally instructions or a prompt. + fewshot_rng (random.Random): Seeded sampler to chose samples with + + Returns: + str: The original preamble with num_fewshot examples appended + """ + few_shot_text = preamble + + if num_fewshot > 0: + fewshot_idxs = get_fewshot_sample_idxs( + len(self.dataset), + num_fewshot, + example_idx, + fewshot_rng, + ) + for fewshot_idx in fewshot_idxs: + ctxt = self.construct_context( + self.dataset[fewshot_idx], + few_shot_text, + add_answer=True, + ) + few_shot_text += ctxt + + return few_shot_text + + def construct_context(self, + example: Dict, + preceding_text: str = '', + add_answer: bool = False) -> str: + """Takes an example and constructs a context, i.e. the input the model. + + reads for this example. Optionally adds the correct answer (for fewshot + examples) and handles example delimiters. + + Args: + example (Dict): The example from which to construct the context + preceding_text (str): Any preceding text, used as a check for prepending self.example_delimiter + add_answer (bool): Bool for whether or not to add the answer on the end of the context (e.g. for fewshot examples) + + Returns: + str: The constructed context. The default output context is + formatted as follows: f'{self.prelimiter}{example[self.context_key]}{self.continuation_delimiter}' + """ + ctxt = example[self.context_key] + ctxt = f'{self.prelimiter}{ctxt}' + if len(preceding_text) > 0: + ctxt = f'{self.example_delimiter}{ctxt}' + ctxt = f'{ctxt}{self.continuation_delimiter}' + if add_answer: + ctxt = f'{ctxt}{self.get_answer_from_example(example, in_context=add_answer)}' + return ctxt + + def get_answer_from_example(self, + example: Dict[str, Any], + in_context: bool = False) -> str: + """Returns the answer from the example. + + Args: + example (Dict): The example from which to retrieve the answer + + Returns: + str: The answer in the example + """ + cont = example[self.answer_key] + if self.prefix_space and not cont.startswith(' ') and not in_context: + cont = f' {cont}' + return cont + + def _fix_eos_on_preamble(self, input_ids: List[int]) -> List[int]: + """If the input_ids is empty then input_ids will be a 0-length List. + + unless the tokenizer adds special tokens to empty strings (e.g. OPT + tokenizer). If there is an EOS token added, we need to remove it so it + is not in the middle of the prompt, as the specific eval question's + prompt will follow the input_ids. + + Args: + input_ids (List): The tokenized input + + Returns: + input_ids: The tokenized input conditionally edited + """ + if (self.tokenizer.eos_token_id is not None and len(input_ids) > 1 and + input_ids[-1] == self.tokenizer.eos_token_id): + input_ids = input_ids[:-1] + return input_ids + + def tokenize_example(self, prompt_and_fewshot: str, ctxt: str, + example: Dict) -> Dict[str, Any]: + """Runs text through the tokenizer and handle special cases. + + Args: + prompt_and_fewshot (str): The collection of the prompt and fewshot examples that belongs before the example's context + ctxt (str): The specific example's derrived context + example (Dict): The example as a dictionary. Used for additional processing in inherited classes. + + Returns: + Dict: Dictionary with the tokenized data + """ + tokenized_example = {} + # Always add special tokens to preamble + preamble = self.tokenizer(prompt_and_fewshot)['input_ids'] + assert isinstance(preamble, list) + preamble = self._fix_eos_on_preamble(preamble) + if self.strip_data: + # rstrip context because a prompt ending in a space results in degenerate output + ctxt = ctxt.rstrip() + # Never add special tokens to context + tokenized_context = self.tokenizer( + ctxt, add_special_tokens=False)['input_ids'] + assert isinstance(preamble, list) + assert isinstance(tokenized_context, list) + + tokenized_context = preamble + tokenized_context + + if self.tokenize_labels: + # Never add special tokens to answer + tokenized_answer = self.tokenizer( + self.get_answer_from_example(example), + add_special_tokens=False)['input_ids'] + assert isinstance(tokenized_answer, list) + trimmed_context = trim_context(tokenized_context, tokenized_answer, + self.padding_size) + assert isinstance(trimmed_context, list) + continuation_indices = get_continuation_span( + trimmed_context, tokenized_answer) + padded_context = make_padded_input(trimmed_context, + tokenized_answer, + self.padding_size, + self.pad_tok_id, + self.padding_side) + + tokenized_example[self.context_key] = padded_context + tokenized_example[self.answer_key] = tokenized_answer + tokenized_example['continuation_indices'] = continuation_indices + else: + assert isinstance(tokenized_context, list) + trimmed_context = trim_context( + tokenized_context, + [], + self.padding_size, + ) + assert isinstance(trimmed_context, list) + padded_context = make_padded_input(trimmed_context, [], + self.padding_size, + self.pad_tok_id, + self.padding_side) + + tokenized_example[self.context_key] = padded_context + tokenized_example[self.answer_key] = self.get_answer_from_example( + example) + + return tokenized_example + + def _prep_example( + self, + example: Dict, + example_idx: int, + num_fewshot: int, + prompt_string: str, + fewshot_rng: random.Random, + ) -> Dict[str, Any]: + """Prepares a single example from a HF Dataset into tokenized format. + + with prompt and fewshot examples. + + Each task consists of a context and a continuation as well as an optional prompt and optional list of + example context/continuation pairs which precede the test context/continuation pair. + + Args: + example (Dict): A Dictionary from the hf dataset + example_idx (int): The index of example + num_fewshot (int): Number of examples context/continuation pairs to prepend to the test pair + prompt_string (str): The prompt to prepend to all inputs + fewshot_rng (random.Random): Random number generator to use for fewshot sampling + + Returns: + Dict: Contains a dictionary with the tokenized data + """ + prompt_and_fewshot = self._generate_few_shot_prompt( + num_fewshot, example_idx, prompt_string, fewshot_rng) + ctxt = self.construct_context(example, + prompt_and_fewshot, + add_answer=False) + tokenized_example = self.tokenize_example(prompt_and_fewshot, ctxt, + example) + return tokenized_example + + def collate_fn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]: + """The function that the dataloader uses to accumulate data into. + + batches. + + Args: + data (List): List of tokenized datapoints (dicts returned by self._tokenize_example) + + Returns: + Dict: Dictionary for a single batch + """ + batch = copy.deepcopy(self.base_batch) + for data_pair in data: + for batch_key, data_key in self.batch_mapping.items(): + batch[batch_key].append(data_pair[data_key]) + if 'continuation_indices' in data_pair: + batch['continuation_indices'].append( + data_pair['continuation_indices']) + + batch = convert_tokens_to_tensors(batch, self.tokenize_labels) + batch['attention_mask'] = ~(batch['input_ids'] == self.pad_tok_id) + return batch + + def split_batch(self, batch: Any, + microbatch_size: Union[int, float]) -> Sequence[Any]: + """Handling for certain specialty columns that must be split into. + + batches in different formats. + + Args: + batch (Dict): Batch of data + microbatch_size (int | float): Size of microbatches + + Returns: + List: List of chunked batches + """ + # Don't split kwargs that don't change + # Normally split torch tensors + # List split lists of strings + if isinstance(microbatch_size, float): + raise ValueError( + 'split_batch does not support floating point microbatch_size.') + chunked = {} + for k, v in batch.items(): + if k in self.static_keys: + # Defer broadcasting until we know num_chunks + pass + elif k in self.list_keys: + chunked[k] = _split_list(v, microbatch_size) + elif k in self.tensor_keys: + chunked[k] = _default_split_batch(v, microbatch_size) + else: + raise ValueError(f'Unexpected key {k} in batch splitting') + num_chunks = len(chunked['input_ids']) + for k, v in batch.items(): + if k in self.static_keys: + chunked[k] = [v] * num_chunks + + batched_list = [ + {k: v[idx] for k, v in chunked.items()} for idx in range(num_chunks) + ] + return batched_list + + +class InContextLearningGenerationTaskWithAnswersDataset(InContextLearningDataset + ): + """A dataset that constructs batches for in-context learning generation. + + tasks with answers. Generation tasks evaluate a model's ability to + generate responses and score them against a set of gold-standard answers. + + The input format is expected to be a jsonl file with the following fields: + - context: The question + - answer: The preferred answer to the question + - aliases: A list of aliases for the answer + + See InContextLearningDataset for more details. + + Additional Args: + cot_delimiter (str): Delimiter to place between the chain of thought and continuations. + early_stopping_criteria (Optional[List[str]]): Optional strings to trigger early stopping. + do_normalization (bool): Flag indicating whether to normalize generations before providing output. + """ + + def __init__(self, + cot_delimiter: str = '', + early_stopping_criteria: Optional[List[str]] = None, + do_normalization: bool = True, + *args: Any, + **kwargs: Any): + if kwargs['tokenizer'].eos_token_id is None: + raise ValueError( + '`InContextLearningGenerationTaskWithAnswersDataset` tokenizer must have non-null `eos_token_id`' + ) + self.cot_delimiter = cot_delimiter + self.has_cot = False + self.max_answer_length = 0 + static_keys = [ + 'mode', 'cot_delimiter', 'generation_kwargs', 'do_normalization', + 'stopping_criteria' + ] + tensor_keys = ['input_ids', 'attention_mask'] + list_keys = ['labels'] + super().__init__(padding_side='left', + tokenize_labels=False, + static_keys=static_keys, + list_keys=list_keys, + tensor_keys=tensor_keys, + *args, + **kwargs) + # NOTE: set these after init call because they take class vars + self.early_stopping_criteria = early_stopping_criteria + self.base_batch = { + 'input_ids': [], + 'mode': 'generate', + 'labels': [], + 'cot_delimiter': self.cot_delimiter, + 'stopping_criteria': early_stopping_criteria, + 'do_normalization': do_normalization, + 'generation_kwargs': { + 'pad_token_id': self.pad_tok_id, + 'use_cache': True, + 'eos_token_id': self.tokenizer.eos_token_id, + 'max_new_tokens': max(self.max_answer_length, 1) + }, + } + self.batch_mapping = { + 'input_ids': self.context_key, + 'labels': 'aliases', + } + if 'generation_kwargs' in kwargs: + self.update_generation_kwargs(kwargs['generation_kwargs']) + + def read_dataset( + self, + dataset_uri: str, + destination_path: str, + hf_loading_vars: Dict, + hf_parsing_map: Dict, + ) -> 'HFDataset': + dataset = super().read_dataset(dataset_uri, destination_path, + hf_loading_vars, hf_parsing_map) + self.has_cot = 'chain_of_thought' in dataset.features + dataset = dataset.map( + lambda examples: { + 'context': + examples['context'], + 'answer': + examples['answer'], + 'aliases': + set([examples['answer']] + examples.get('aliases', [])), + 'chain_of_thought': + examples.get('chain_of_thought', ''), + }) + self.max_answer_length = self._get_max_answer_length(dataset) + # NOTE: This is the only time we use the class variable padding_size. + if self.max_seq_len < self.max_answer_length: + log.warning(f'`max_seq_len` {self.max_seq_len} was less than `max_answer_len`: {self.max_answer_length}' \ + + ' setting `max_seq_len`=`max_answer_len`') + self.max_seq_len = self.max_answer_length + self.padding_size = self.max_seq_len - self.max_answer_length + return dataset + + def get_answer_from_example(self, + example: Dict, + in_context: bool = False) -> str: + """Returns the answer from the example. Applies chain of thought if. + + self.has_cot is marked as true. + + Args: + example (Dict): The example from which to retrieve the answer + + Returns: + str: The answer in from the example with chain of thought and delimiter if needed + """ + if self.has_cot: + return f'{example["chain_of_thought"]}{self.cot_delimiter}{example[self.answer_key]}' + else: + return example[self.answer_key] + + def tokenize_example(self, prompt_and_fewshot: str, ctxt: str, + example: Dict) -> Dict[str, Any]: + """Run text through the tokenizer and handle special cases. + + Args: + prompt_and_fewshot (str): The collection of the prompt and fewshot examples that belongs before the example's context + ctx (str): The specific example's derrived context + example (Dict): The example as a dictionary. + + Returns: + Dict: Dictionary with the tokenized data + """ + tokenized_example = super().tokenize_example(prompt_and_fewshot, ctxt, + example) + tokenized_example['aliases'] = list(example.get('aliases', [])) + return tokenized_example + + def _get_max_answer_length(self, dataset: Iterable[dict]) -> int: + """Loops over the dataset and finds the longest answer length. + + Returns: + int: The maximum answer length with an additional buffer of 10 if chain of thought is present + """ + max_answer_length = 0 + for example in dataset: + all_answers = [example[self.answer_key]] + list( + example.get('aliases', [])) + for answer in all_answers: + if self.has_cot: + response = ( + f'{example["chain_of_thought"]}{self.cot_delimiter}{answer}' + ) + else: + response = answer + tokenized_repsonse = self.tokenizer(response)['input_ids'] + assert isinstance(tokenized_repsonse, list) + max_answer_length = max(max_answer_length, + len(tokenized_repsonse)) + max_answer_length = max_answer_length + ( + _MAX_ANSWER_BUFFER_LENGTH if len(self.cot_delimiter) > 0 else 0) + return max_answer_length + + def collate_fn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]: + batch = super().collate_fn(data) + batch_size = batch['input_ids'].shape[0] + stopping_criteria = None + if self.early_stopping_criteria: + if stop_sequences_criteria is None: # pyright: ignore [reportUnnecessaryComparison] + raise MissingConditionalImportError( + extra_deps_group='nlp', + conda_package='transformers', + conda_channel='conda-forge') + stopping_criteria = stop_sequences_criteria( + self.tokenizer, self.early_stopping_criteria, batch_size) + batch['generation_kwargs']['stopping_criteria'] = stopping_criteria + return batch + + +class InContextLearningLMTaskDataset(InContextLearningDataset): + """A dataset that constructs batches for in-context learning language. + + modeling evaluation. Language modeling tasks test a model's ability to + properly predict tokens based on preceding tokens. + + The input format is expected to be a jsonl file with the following fields: + - context: Preceding text + - continuation: The expected continuation + + See InContextLearningDataset for more details. + """ + + def __init__(self, *args: Any, **kwargs: Any): + super().__init__(answer_key='continuation', + static_keys=['mode'], + tensor_keys=[ + 'input_ids', 'continuation_indices', 'labels', + 'attention_mask' + ], + base_batch={ + 'input_ids': [], + 'continuation_indices': [], + 'mode': 'icl_task', + 'labels': [] + }, + batch_mapping={ + 'input_ids': 'context', + 'labels': 'context' + }, + padding_side='right', + *args, + **kwargs) + + +class InContextLearningMultipleChoiceTaskDataset(InContextLearningDataset): + """A dataset that construct batches for in-context learning multiple choice. + + evaluation. + + If each question has N answer choices, we construct N distinct inputs per question. In order to ensure + consistency across multi-GPU, we set the batch size to be `min(N, batch_size)` so that all N + inputs per question can stored in the same batch. + + The default input format is a jsonl file with the following fields: + - query: The preceding text, question, or document relevant to the choices + - gold: Index of the correct choice under 'choices' + - choices: A list of strings, each being one of the potential choices + + Each batch then consists of ``|batch_size // N|`` distinct questions and has the following the structure. + - input_ids: Input tensor ``|batch x seqlen x # tokens|`` + - continuation_indices: List of ``|batch|`` consisting of tensors indicating which indices in the sequence correspond to the question answer (aka continuation) + - mode: Indicates to the model that this is an ICL task and may rely on a custom code path to properly update metrics + - labels: Identical to the input, used by the model to calculate loss/metrics + - gold_indices: List of length ``|batch_size // N|`` indicating for each question, which of the answers is correct (via an integer [0, N-1]) + - choice_groupings: Indicates which indices of the batch correspond to which questions + + Additional Args: + choices_key (str): The key under which the choices are stored in the saved dataset. Defaults to 'choices'. + """ + + def __init__(self, + choices_key: str = 'choices', + static_keys: Optional[List] = None, + list_of_tensors_keys: Optional[List] = None, + list_of_tuples_keys: Optional[List] = None, + list_of_primitives: Optional[List] = None, + *args: Any, + **kwargs: Any): + self.choices_key = choices_key + base_batch = { + 'input_ids': [], + 'continuation_indices': [], + 'mode': 'icl_task', + 'labels': [], + 'gold_indices': [], + 'choice_groupings': [], + } + context_key = kwargs.pop('context_key', 'query') + static_keys = kwargs.pop('static_keys', ['mode', 'generation_kwargs']) + tensor_keys = kwargs.pop('tensor_keys', + ['input_ids', 'labels', 'attention_mask']) + self.list_of_tensors_keys = list_of_tensors_keys or [ + 'continuation_indices' + ] + self.list_of_tuples_keys = list_of_tuples_keys or ['choice_groupings'] + self.list_of_primitives = list_of_primitives or ['gold_indices'] + super().__init__(context_key=context_key, + base_batch=base_batch, + static_keys=static_keys, + tensor_keys=tensor_keys, + padding_side='right', + *args, + **kwargs) + self.num_choices = len(self.dataset[0][self.choices_key]) + self.batch_mapping_per_choice = { + 'input_ids': 'context', + 'labels': 'context' + } + self.batch_map_per_example = {'gold_indices': 'gold'} + + def get_answer_from_example(self, + example: Dict, + in_context: bool = False) -> str: + """Returns the correct answer from the example's choices. + + Args: + example (Dict): The example from which to retrieve the answer + + Returns: + str: The full string of the correct answer based on the 'gold' key + """ + choices = example[self.choices_key] + gold_idx = example['gold'] + return choices[gold_idx] + + def tokenize_example(self, prompt_and_fewshot: str, ctxt: str, + example: Dict) -> Dict[str, Any]: + """Runs text through the tokenizer and handle special cases. + + Args: + prompt_and_fewshot (str): The collection of the prompt and fewshot examples that belongs before the example's context + ctx (str): The specific example's derrived context + example (Dict): The example as a dictionary. + + Returns: + Dict: Dictionary with the tokenized data + """ + # NOTE: some of this is repeated from super class but for loop makes things considerably different + tokenized_example = {} + # Always add special tokens to preamble + preamble = self.tokenizer(prompt_and_fewshot)['input_ids'] + assert isinstance(preamble, list) + preamble = self._fix_eos_on_preamble(preamble) + if self.strip_data: + # rstrip context because a prompt ending in a space results in degenerate output + ctxt = ctxt.rstrip() + # Never add special tokens to context + tokenized_context = self.tokenizer( + ctxt, add_special_tokens=False)['input_ids'] + assert isinstance(tokenized_context, list) + tokenized_context = preamble + tokenized_context + + tokenized_example[self.context_key] = [] + tokenized_example[self.answer_key] = [] + tokenized_example['continuation_indices'] = [] + # NOTE: Treating tokenize_labels as True for all MC datasets (required for our MC accuracy metric) + for choice in example[self.choices_key]: + if self.prefix_space: + choice = f' {choice}' if not choice.startswith(' ') else choice + + # Never add special tokens to answer + tokenized_answer = self.tokenizer( + choice, add_special_tokens=False)['input_ids'] + assert isinstance(tokenized_context, list) + assert isinstance(tokenized_answer, list) + trimmed_context = trim_context(tokenized_context, tokenized_answer, + self.padding_size) + assert isinstance(trimmed_context, list) + continuation_indices = get_continuation_span( + trimmed_context, tokenized_answer) + padded_context = make_padded_input( + trimmed_context, + tokenized_answer, + self.padding_size, + self.pad_tok_id, + self.padding_side, + ) + + tokenized_example[self.context_key].append(padded_context) + tokenized_example[self.answer_key].append(tokenized_answer) + tokenized_example['continuation_indices'].append( + continuation_indices) + + tokenized_example['gold'] = example['gold'] + return tokenized_example + + def collate_fn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]: + """The function that the dataloader uses to accumulate data into. + + batches. We run each distinct query + answer choice through the model + separately and determine which answer has the lowest per-token- + perplexity. + + If each question has N possible choices, all N must be grouped together as distinct elements of the batch + since the batch may consist of multiple questions, the choice_groupings indicates + which contiguous sequences of elements in the batch correspond to which question + gold_indices indicates which of the [0, N-1] choices is the correct one for each question. + Args: + data (List): List of tokenized datapoints (dicts returned by self._tokenize_example) + + Returns: + Dict: Dictionary for a single batch + """ + batch = copy.deepcopy(self.base_batch) + for data_pair in data: + choice_start_idx = len(batch['continuation_indices']) + # NOTE: not using batch_mapping + for i, context_enc in enumerate(data_pair[self.context_key]): + batch['input_ids'].append(context_enc) + batch['continuation_indices'].append( + data_pair['continuation_indices'][i]) + batch['labels'].append(context_enc) + + batch['gold_indices'].append(data_pair['gold']) + choice_end_idx = len(batch['continuation_indices']) + batch['choice_groupings'].append((choice_start_idx, choice_end_idx)) + + batch = convert_tokens_to_tensors(batch, self.tokenize_labels) + batch['attention_mask'] = ~(batch['input_ids'] == self.pad_tok_id) + return batch + + def get_num_samples_in_batch(self, batch: Dict[str, torch.Tensor]) -> int: + return batch['input_ids'].shape[0] // self.num_choices + + def split_batch(self, batch: Any, + microbatch_size: Union[int, float]) -> Sequence[Any]: + """Split batch while ensuring all continuations are in the same. + + microbatch. + + In ICL Multiple Choice, we duplicate each data point for each possible continuation. + When splitting a batch, we have logical example, which refer to one possible question, + and real example, which refers to one possible continuation. As example count and + microbatch_size are tracked in logical example, we split logical attributes by + microbatch_size and real attributes by microbatch_size * num_choices. + Args: + batch (Dict): Batch of data + microbatch_size (int | float): Size of microbatches + + Returns: + list: List of chunked batches + """ + if isinstance(microbatch_size, float): + raise ValueError( + 'split_batch does not support floating point microbatch_size.') + chunked = {} + for k, v in batch.items(): + if k in self.static_keys: + # Defer broadcasting primitives until we know num_chunks + pass + elif type(v) == list: + # list of tensors - 'continuation_indices' + if k in self.list_of_tensors_keys: + chunked[k] = _split_list(v, + microbatch_size * self.num_choices) + # list of tuples - 'choice_groupings' + elif k in self.list_of_tuples_keys: + chunked[k] = _split_list(v, microbatch_size) + # list - 'gold_indices' + elif k in self.list_of_primitives: + chunked[k] = _default_split_batch(v, microbatch_size) + else: + raise ValueError(f'Unexpected key {k} in list splitting') + elif k in self.tensor_keys: + chunked[k] = _default_split_batch( + v, microbatch_size * self.num_choices) + else: + raise ValueError(f'Unexpected key {k} in batch splitting') + num_chunks = len(chunked['input_ids']) + # Broadcast primitives to all chunks + for k, v in batch.items(): + if k in self.static_keys: + chunked[k] = [v] * num_chunks + + return [ + {k: v[idx] for k, v in chunked.items()} for idx in range(num_chunks) + ] + + +class InContextLearningSchemaTaskDataset( + InContextLearningMultipleChoiceTaskDataset): + """A dataset that constructs batches for in-context learning schema. + + evaluation. A schema task involves sentences with a fill-in-the-blank where + the user needs to choose the correct word to fill in from a set of N + options. We use the partial evaluation technique from + https://arxiv.org/abs/1806.02847 to determine the model's choice of fill-in + word. + + The default input format is a jsonl file with the following fields: + - context_options: List of strings corresponding to possible preceding context options for the continuation + - gold: Index of the correct context from 'context_options' + - continuation: The finishing continuation + + Each batch then consists of ``batch_size // N`` distinct tasks and has the following the structure + - input_ids: Input tensor ``batch x seqlen x # of tokens`` + - continuation_indices: List of ``batch`` consisting of tensors indicating which indices in the sequence correspond to the question answer (aka continuation) + - mode: Indicates to the model that this is an ICL task and may rely on a custom code path to properly update metrics + - labels: Identical to the input, used by the model to calculate loss/metrics + - gold_indices: List of length ``batch_size // N`` indicating for each question, which of the answers is correct (via an integer [0, N-1]) + - choice_groupings: Indicates which indices of the batch correspond to which questions + """ + + def __init__(self, + choices_key: str = 'context_options', + *args: Any, + **kwargs: Any): + static_keys = ['mode'] + tensor_keys = ['input_ids', 'labels', 'attention_mask'] + list_of_tensors_keys = ['continuation_indices'] + super().__init__(choices_key=choices_key, + context_key=choices_key, + static_keys=static_keys, + tensor_keys=tensor_keys, + list_of_tensors_keys=list_of_tensors_keys, + *args, + **kwargs) + self.base_batch = { + 'input_ids': [], + 'continuation_indices': [], + 'mode': 'icl_task', + 'labels': [], + 'gold_indices': [], + 'choice_groupings': [], + } + + def construct_context(self, + example: Dict[str, Any], + preceding_text: str = '', + add_answer: bool = False) -> str: + """Takes a example and constructs a context with the correct context. + + for. + + the example's continuation. + + Args: + example (Dict): The example from which to construct the context + preceding_text (str): Any preceding text, needed to if self.example_delimiter is needed at the beginning + add_answer (bool): This will always be true when calling this function for SchemaTaskDataset + + Returns: + str: The single correct context for a given continuation + """ + context_options = example[self.choices_key] + gold_idx = example['gold'] + continuation = example['continuation'] + context = context_options[gold_idx] + if len(preceding_text) > 0: + context = f'{self.example_delimiter}{context}' + context = f'{self.prelimiter}{context}{self.continuation_delimiter}{continuation}' + return context + + def _construct_multiple_contexts(self, + example: Dict, + preceding_text: str = '') -> List[str]: + """Takes a example and constructs all contexts. + + Optionally, appends this to preceeding text (such as a prompt or fewshot examples). + + Args: + example (Dict): The example from which to construct the context + preceding_text (str): Any preceding text, needed to if self.example_delimiter is needed at the beginning + + Returns: + list: All context options for the selected example with formatting + """ + context_options = example[self.choices_key] + if len(preceding_text) > 0: + if self.strip_data: + cont_del = self.continuation_delimiter.rstrip() + else: + cont_del = self.continuation_delimiter + context_options = [ + f'{self.prelimiter}{self.example_delimiter}{c}{cont_del}' + for c in context_options + ] + else: + context_options = [f'{self.prelimiter}{c}' for c in context_options] + return context_options + + def _prep_example( + self, + example: Dict, + example_idx: int, + num_fewshot: int, + prompt_string: str, + fewshot_rng: random.Random, + ) -> Dict[str, Any]: + """Prepares a single example from a HF Dataset into tokenized format. + + with prompt and fewshot examples. + + Each task consists of multiple contexts and a single, correct continuation. Will preprend fewshot examples and + prompt if present. + + Args: + example (Dict): A dictionary from the hf dataset + example_idx (int): The index of example + num_fewshot (int): Number of examples context/continuation pairs to prepend to the test pair + prompt_string (str): The prompt to prepend to all inputs + fewshot_rng (random.Random): Random number generator to use for fewshot sampling + + Returns: + Dict: Contains a dictionary with the tokenized data + """ + prompt_and_fewshot = self._generate_few_shot_prompt( + num_fewshot, example_idx, prompt_string, fewshot_rng) + ctxt = self._construct_multiple_contexts(example, prompt_and_fewshot) + tokenized_example = self.tokenize_example(prompt_and_fewshot, ctxt, + example) + return tokenized_example + + def tokenize_example(self, prompt_and_fewshot: str, + context_options: List[str], + example: Dict) -> Dict[str, Any]: + """Runs text through the tokenizer and handle special cases. + + Args: + prompt_and_fewshot (str): The collection of the prompt and fewshot examples that belongs before the example's context + ctx (str): The specific example's derrived context + example (Dict): The example as a dictionary. + + Returns: + Dict: Dictionary with the tokenized data + """ + tokenized_example = {} + preamble = self.tokenizer(prompt_and_fewshot)['input_ids'] + assert isinstance(preamble, list) + preamble = self._fix_eos_on_preamble(preamble) + encoded_contexts = [ + preamble + + # pyright: ignore[reportOperatorIssue, reportGeneralTypeIssues] + self.tokenizer(c, add_special_tokens=False)[ + 'input_ids'] # pyright: ignore[reportOperatorIssue, ] + for c in context_options + ] + continuation = example['continuation'] + if self.prefix_space: + continuation = (f' {continuation}' if + not continuation.startswith(' ') else continuation) + tokenized_continuation = self.tokenizer( + continuation, add_special_tokens=False)['input_ids'] + + tokenized_example[self.context_key] = [] + tokenized_example['continuation_indices'] = [] + tokenized_example[self.answer_key] = [] + for context in encoded_contexts: + assert isinstance(context, list) + assert isinstance(tokenized_continuation, list) + trimmed_context = trim_context(context, tokenized_continuation, + self.padding_size) + assert isinstance(trimmed_context, list) + continuation_indices = get_continuation_span( + trimmed_context, tokenized_continuation) + padded_context = make_padded_input(trimmed_context, + tokenized_continuation, + self.padding_size, + self.pad_tok_id, + self.padding_side) + tokenized_example[self.context_key].append(padded_context) + tokenized_example['continuation_indices'].append( + continuation_indices) + tokenized_example[self.answer_key].append(tokenized_continuation) + + tokenized_example['gold'] = example['gold'] + return tokenized_example + + +class InContextLearningCodeEvalDataset(InContextLearningDataset): + """A dataset that constructs batches for in-context learning code. + + evaluation. + + The input format is expected to be a jsonl file with the following fields: + + - task_id: Label of given task + - prompt: The code snippet that must be completed + - entry_point: The entry to the function/code snippet to generate + - canonical_solution: Working solution + - test: The checker code that will run to completion if the code generation is valid and otherwise throw assertion + - test_inputs: List of test inputs + - test_outputs: List of test outputs + - language: The language of the code snippet + + Each batch then consists of the following the structure + + - input_ids: Input tensor batch x seqlen x num tokens + - mode: Indicates to the model that this is an ICL task and may rely on a custom code path to properly update metrics + - mode: Always set to 'generate' + - labels: Exact solution for the coding problem + - prompts: Prompt for the task + - entry_points: List of entry points + - test_inputs: List of test inputs + - test_outputs: List of test outputs + - languages: List of languages + - pass_at_k: Passed value for pass_at_k + - generation_kwargs: Dictionary of kwargs neeeded for generation. Includes the following, which will be individually overwritten + by keys in generaiton_kwargs if set (see https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationConfig + for more details): + + - pad_token_id: ID for padding token, derived automatically + - num_beams: How many beams to search for generations, default set to 1 + - do_sample: Determines whether model is sampling or greedily decoding. Always set to True + - use_cache: Whether or not to use past key values to speed up sampling. Always set to True + + Additional Args: + generations_per_sample (int) (defaults to 1): The number of independently computed returned sequences for each element in the batch + pass_at_k (int) (defaults to 1): k for how many chances the model gets to write passing code + """ + + def __init__( + self, + generations_per_sample: int, + pass_at_k: Union[int, list[int]] = 1, + *args: Any, + **kwargs: Any, + ): + if isinstance(pass_at_k, int): + pass_at_k = [pass_at_k] + if generations_per_sample < max(pass_at_k): + raise ValueError( + f'generations_per_sample ({generations_per_sample}) must be greater than or equal to pass_at_k ({pass_at_k}) for code evaluation.' + ) + batch_mapping = { + 'input_ids': 'prompt', + 'prompts': 'prompt_text', + 'tests': 'test', + 'labels': 'canonical_solution', + 'entry_points': 'entry_point', + 'test_inputs': 'test_inputs', + 'test_outputs': 'test_outputs', + 'languages': 'language', + 'sample_id': 'sample_id', + } + # Linting complains if these are not set in init + self.max_prompt_length = 0 + self.max_answer_length = 0 + static_keys = [ + 'mode', + 'pass_at_k', + 'generation_kwargs', + 'generations_per_sample', + 'dataset_size', + ] + list_keys = [ + 'prompts', + 'tests', + 'entry_points', + 'test_inputs', + 'test_outputs', + 'languages', + 'labels', + 'sample_id', + ] + tensor_keys = ['input_ids', 'attention_mask'] + super().__init__( + context_key='prompt', + answer_key='canonical_solution', + strip_dataset=False, + static_keys=static_keys, + list_keys=list_keys, + tensor_keys=tensor_keys, + tokenize_labels=False, + padding_side='left', + batch_mapping=batch_mapping, + *args, + **kwargs, + ) + self._set_max_prompt_and_answer_lengths() + if self.max_seq_len < self.max_prompt_length: + log.warning(f'`max_seq_len` {self.max_seq_len} was less than `max_prompt_len`: {self.max_prompt_length}' \ + + ' setting `max_seq_len`=`max_prompt_len`') + self.max_seq_len = self.max_prompt_length + dataset_size = len(self.dataset) + self.dataset = self.dataset.map(self._trim_padding) + self.dataset = self.repeat_dataset(self.dataset, generations_per_sample) + + if self.max_answer_length < self.max_seq_len - self.max_prompt_length: + max_new_tokens = self.max_answer_length + else: + max_new_tokens = self.max_seq_len - self.max_prompt_length + + self.base_batch = { + 'input_ids': [], + 'mode': 'generate', + 'labels': [], + 'prompts': [], + 'tests': [], + 'entry_points': [], + 'test_inputs': [], + 'test_outputs': [], + 'languages': [], + 'pass_at_k': pass_at_k, + 'generation_kwargs': { + 'pad_token_id': self.pad_tok_id, + 'num_beams': 1, # single beam + 'do_sample': True, + 'temperature': 0.2, # good default for code + 'use_cache': True, + 'eos_token_id': self.tokenizer.eos_token_id, + 'max_new_tokens': max(max_new_tokens, 1) + }, + 'sample_id': [], + 'pass_at_k': list(pass_at_k), + 'generations_per_sample': generations_per_sample, + 'dataset_size': dataset_size, + } + if 'generation_kwargs' in kwargs: + self.update_generation_kwargs(kwargs['generation_kwargs']) + + def repeat_dataset(self, dataset: HFDataset, repetitions: int) -> HFDataset: + + def _repeat_dataset(): + for i, sample in enumerate(dataset): + for _ in range(repetitions): + assert isinstance(sample, dict) + yield {'sample_id': i, **sample} + + from datasets import \ + Dataset as HFDataset # pyright: ignore[reportGeneralTypeIssues] + + repeated_dataset = HFDataset.from_generator(_repeat_dataset) + assert isinstance(repeated_dataset, HFDataset) + return repeated_dataset + + def _set_max_prompt_and_answer_lengths(self): + """Iterates through the dataset and finds the maximum prompt length and. + + sequence lengths. + + Returns: + None + """ + max_prompt_length = 0 + max_answer_length = 0 + for example in self.dataset: + assert isinstance(example, Dict) + unpadded_example = [ + token for token in example[self.context_key] + if token != self.pad_tok_id + ] + max_prompt_length = max(max_prompt_length, len(unpadded_example)) + + tokenized_answer = self.tokenizer( + example['canonical_solution'], + add_special_tokens=False)['input_ids'] + assert isinstance(tokenized_answer, list) + len_tokenized_answer = len(tokenized_answer) + max_answer_length = max(max_answer_length, len_tokenized_answer) + + self.max_prompt_length = max_prompt_length + self.max_answer_length = max_answer_length + _MAX_ANSWER_BUFFER_LENGTH + + def _trim_padding(self, example: Dict): + """Adjusts padding to the maximum prompt length rather than max_seq_len. + + Needs to be done after the dataset has been processed because we don't + know the maximum prompt length until after we've tokenized it. + + Returns: + dataset: A HuggingFace Dataset with different padding lengths for example[self.context_key] + """ + # Remove padding tokens applied during tokenization + unpadded_prompt = [ + token for token in example[self.context_key] + if token != self.pad_tok_id + ] + # Reapply padding only to max_prompt_length + full_prompt = trim_context(unpadded_prompt, [], self.max_prompt_length) + padded_context = make_padded_input(full_prompt, [], + self.max_prompt_length, + self.pad_tok_id, self.padding_side) + + example[self.context_key] = padded_context + return example + + def tokenize_example(self, prompt_and_fewshot: str, ctxt: str, + example: Dict) -> Dict[str, Any]: + """Adds extra code task details to the example dictionary. + + See InContextLearningDataset for more details + """ + tokenized_example = super().tokenize_example(prompt_and_fewshot, ctxt, + example) + tokenized_example['prompt_text'] = example['prompt'] + tokenized_example['task_id'] = example['task_id'] + tokenized_example['canonical_solution'] = example['canonical_solution'] + tokenized_example['test'] = example['test'] + tokenized_example['entry_point'] = example['entry_point'] + tokenized_example['test_inputs'] = example['test_inputs'] + tokenized_example['test_outputs'] = example['test_outputs'] + tokenized_example['language'] = example['language'] + return tokenized_example + + +def build_icl_dataloader( + icl_task_type: str, + dataset_uri: str, + tokenizer: transformers.PreTrainedTokenizerBase, + batch_size: int, + max_seq_len: int, + pad_tok_id: int, + num_fewshot: int, + prompt_string: str, # e.g. 'translate english to french:' + example_delimiter: str, # e.g. '\n' + continuation_delimiter: str, # e.g. '' + hf_loading_vars: Dict, + hf_parsing_map: Dict, + destination_path: str, + prelimiter: str, # e.g. 'Question: ' + cot_delimiter: str, # e.g. ' ### ' + fewshot_random_seed: int, + pass_at_k: int, + generations_per_sample: int, + generation_kwargs: Dict, + early_stopping_criteria: Optional[List[str]] = None, + do_normalization: bool = True) -> DataSpec: + """Factory method that builds the specific dataset for the specified. + + icl_task_type. See documentation for `get_icl_task_dataloader` for arugment + documentation. + + When writing a dataset for a new task, here you will need to: + 1. add the dataset to the factory and choose an appropriate string + 2. set the batch size for that task (see InContextLearningMultipleChoiceTaskDataset for why + this might be different) + 3. set the `split_batch` funciton if necessary + """ + if icl_task_type == 'multiple_choice': + dataset = InContextLearningMultipleChoiceTaskDataset( + dataset_uri=dataset_uri, + tokenizer=tokenizer, + max_seq_len=max_seq_len, + pad_tok_id=pad_tok_id, + num_fewshot=num_fewshot, + prompt_string=prompt_string, + example_delimiter=example_delimiter, + continuation_delimiter=continuation_delimiter, + destination_path=destination_path, + prelimiter=prelimiter, + fewshot_random_seed=fewshot_random_seed, + hf_loading_vars=hf_loading_vars, + hf_parsing_map=hf_parsing_map, + generation_kwargs=generation_kwargs, + ) + batch_size = max(dataset.num_choices, batch_size) + effective_batchsize = batch_size // dataset.num_choices + elif icl_task_type == 'schema': + dataset = InContextLearningSchemaTaskDataset( + dataset_uri=dataset_uri, + tokenizer=tokenizer, + max_seq_len=max_seq_len, + pad_tok_id=pad_tok_id, + num_fewshot=num_fewshot, + prompt_string=prompt_string, + example_delimiter=example_delimiter, + continuation_delimiter=continuation_delimiter, + destination_path=destination_path, + prelimiter=prelimiter, + fewshot_random_seed=fewshot_random_seed, + hf_loading_vars=hf_loading_vars, + hf_parsing_map=hf_parsing_map, + generation_kwargs=generation_kwargs, + ) + batch_size = max(dataset.num_choices, batch_size) + effective_batchsize = batch_size // dataset.num_choices + elif icl_task_type == 'language_modeling': + dataset = InContextLearningLMTaskDataset( + dataset_uri=dataset_uri, + tokenizer=tokenizer, + max_seq_len=max_seq_len, + pad_tok_id=pad_tok_id, + num_fewshot=num_fewshot, + prompt_string=prompt_string, + example_delimiter=example_delimiter, + continuation_delimiter=continuation_delimiter, + destination_path=destination_path, + prelimiter=prelimiter, + fewshot_random_seed=fewshot_random_seed, + hf_loading_vars=hf_loading_vars, + hf_parsing_map=hf_parsing_map, + generation_kwargs=generation_kwargs, + ) + effective_batchsize = batch_size + elif icl_task_type == 'generation_task_with_answers' or icl_task_type == 'question_answering': + if icl_task_type == 'question_answering': + warnings.warn( + VersionedDeprecationWarning( + "ICL task type 'question_answering' is now deprecated. Use identifier 'generation_task_with_answers'", + 'v0.7.0')) + dataset = InContextLearningGenerationTaskWithAnswersDataset( + dataset_uri=dataset_uri, + tokenizer=tokenizer, + max_seq_len=max_seq_len, + pad_tok_id=pad_tok_id, + num_fewshot=num_fewshot, + prompt_string=prompt_string, + example_delimiter=example_delimiter, + continuation_delimiter=continuation_delimiter, + destination_path=destination_path, + prelimiter=prelimiter, + fewshot_random_seed=fewshot_random_seed, + hf_loading_vars=hf_loading_vars, + hf_parsing_map=hf_parsing_map, + cot_delimiter=cot_delimiter, + early_stopping_criteria=early_stopping_criteria, + do_normalization=do_normalization, + generation_kwargs=generation_kwargs, + ) + effective_batchsize = batch_size + elif icl_task_type == 'code_evaluation': + warnings.warn( + VersionedDeprecationWarning( + "ICL task type 'code_evaluation' is deprecated and will no longer be supported. ", + 'v0.7.0')) + dataset = InContextLearningCodeEvalDataset( + dataset_uri=dataset_uri, + tokenizer=tokenizer, + max_seq_len=max_seq_len, + pad_tok_id=pad_tok_id, + num_fewshot=num_fewshot, + prompt_string=prompt_string, + example_delimiter=example_delimiter, + continuation_delimiter=continuation_delimiter, + destination_path=destination_path, + prelimiter=prelimiter, + fewshot_random_seed=fewshot_random_seed, + hf_loading_vars=hf_loading_vars, + hf_parsing_map=hf_parsing_map, + pass_at_k=pass_at_k, + generations_per_sample=generations_per_sample, + generation_kwargs=generation_kwargs, + ) + effective_batchsize = batch_size + else: + raise Exception(f'Unrecognized ICL task type: {icl_task_type}') + + sampler = dist.get_sampler(dataset, drop_last=False, shuffle=False) + + split_batch = None + if isinstance( + dataset, + ( + InContextLearningMultipleChoiceTaskDataset, + InContextLearningGenerationTaskWithAnswersDataset, + InContextLearningCodeEvalDataset, + ), + ): + split_batch = dataset.split_batch + + return DataSpec( + DataLoader( + dataset, + batch_size=effective_batchsize, + sampler=sampler, + collate_fn=dataset.collate_fn, + ), + device_transforms=None, + get_num_samples_in_batch=dataset.get_num_samples_in_batch, + split_batch=split_batch, + ) + + +def partition_dataset_by_category(dataset_uri: str, destination_path: str, + hf_loading_vars: Dict, + hf_parsing_map: Dict) -> Dict[str, str]: + """If has_categories is enabled, we partition the dataset into a separate. + + dataset for each category value in the data and write each partition to a + local file. + + Args: + dataset_uri (str): Location of dataset. + destination_path (str): Base destination path, we will write a separate partition off this URI for each category. + + Raises: + MissingConditionalImportError: If datasets not installed raise exception. + Exception: If 'category' key missing from dataset, raise exception. + + Returns: + Dict[str, str]: Mapping of category names to partitioned dataset local files names. + """ + if dataset_uri.startswith('hf://'): + dataset_uri = dataset_uri.replace('hf://', '') + dataset = load_dataset(dataset_uri, **hf_loading_vars) + assert isinstance(dataset, HFDataset) or isinstance( + dataset, IterableDataset) + if hf_parsing_map: + dataset_parsing_func = lambda example: { + k: ' '.join([str(example[col]) for col in v]) + for k, v in hf_parsing_map.items() + } + assert hasattr(dataset, 'column_names') + dataset = dataset.map(dataset_parsing_func, + remove_columns=dataset.column_names) + else: + with dist.local_rank_zero_download_and_wait(destination_path): + if dist.get_local_rank() == 0: + get_file(dataset_uri, destination_path, overwrite=True) + dataset = load_dataset('json', + data_files=destination_path, + split='train', + streaming=False) + assert isinstance(dataset, HFDataset) or isinstance(dataset, + IterableDataset) + assert hasattr(dataset, 'features') + assert dataset.features is not None + if 'category' not in dataset.features.keys(): + raise Exception(f"""Attempted to partition dataset by `category` \ + but it doesn't have a `category` key. \ + Got keys: {str(list(dataset.features.keys()))}""") + categories = sorted( + set(dataset['category'] + )) # pyright: ignore[reportIndexIssue, reportGeneralTypeIssues] + output_files = {} + for cat in categories: + path = destination_path.split('/') + cat_dest = '/'.join(path[:-1]) + f'/{cat}_{path[-1]}' + tmp_path_to_broadcast = str(os.path.abspath(cat_dest)) + gathered_paths = dist.all_gather_object(tmp_path_to_broadcast) + if dist.get_local_rank() == 0: + subset = [ + l for l in dataset if + l['category'] == cat # pyright: ignore[reportGeneralTypeIssues] + ] # pyright: ignore[reportArgumentType, reportCallIssue] + with open(gathered_paths[0], 'w', encoding='utf8') as f: + for l in subset: + f.write(json.dumps(l, ensure_ascii=False) + '\n') + output_files[cat] = cat_dest + return output_files + + +def get_icl_task_dataloader( + icl_task_type: str, + dataset_uri: str, + tokenizer: Union[transformers.PreTrainedTokenizer, + transformers.PreTrainedTokenizerFast], + batch_size: int, + max_seq_len: int, + pad_tok_id: int, + num_fewshot: int, + prompt_string: str, # e.g. 'translate english to french:' + example_delimiter: str, # e.g. '\n' + continuation_delimiter: str = '', + destination_path: str = '', + question_prelimiter: str = '', # e.g. 'Question: ' + fewshot_random_seed: int = 1234, + pass_at_k: int = 1, + generations_per_sample: int = 1, + cot_delimiter: str = '', + has_categories: bool = False, + hf_loading_vars: Optional[Dict] = None, + hf_parsing_map: Optional[Dict] = None, + generation_kwargs: Optional[Dict] = None, + early_stopping_criteria: Optional[List[str]] = None, + do_normalization: bool = True) -> Union[DataSpec, Dict[str, DataSpec]]: + r"""Constructs a dataloader (or dataloaders if has_categories is True) + + capable of evaluating LLMs on in-context learning language modeling tasks, + for example LAMBADA. An example usage is below: + + .. testsetup:: + + import transformers + from composer.models import HuggingFaceModel + from composer.trainer import Trainer + dataset_uri = "/tmp/dataset_uri.jsonl" + dataset = RandomTextClassificationDataset(size=16, use_keys=True) + train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=8) + hf_model, tokenizer = HuggingFaceModel.hf_from_composer_checkpoint('composer-hf-checkpoint.pt') + # At this point, hf_model is randomly initialized + composer_model = HuggingFaceModel(hf_model, hf_tokenizer) + + Example: + + .. testcode:: + + + dl = get_icl_task_dataloader( + 'language_modeling', + dataset_uri, + tokenizer, + batch_size=2, + max_seq_len=2048, + pad_tok_id=tokenizer.pad_token_id, + num_fewshot=10, + prompt_string='translate english to french', + example_delimiter='\\n', + continuation_delimiter='' + ) + eval_evaluator = Evaluator( + label="lambada", + dataloader=dl, + metric_names=['InContextLearningLMAccuracy'] + ) + trainer = Trainer( + model=model, + train_dataloader=train_dataloader, + eval_dataloader=eval_evaluator, + optimizers=optimizer, + max_duration="1ep", + ) + + Args: + icl_task_type (str): Name of icl_task type. One of ['multiple_choice', 'schema', 'language_modeling', 'generation_task_with_answers', 'code_evaluation'] + dataset_uri (str): A local path, a remote path beginning with ``s3://`` or another backend, or a HuggingFace dataset uri prepended with ``hf://``. + Alternate backends must be supported by :meth:`composer.utils.maybe_create_object_store_from_uri`. + A local dataset must consist of rows of JSON data points with task dependant fields. + The default keys expected are "context" and "answer". + tokenizer (transformers.PreTrainedTokenizerBase): The tokenizer used to map between strings and token ids. + batch_size (int): Size of a batch used for eval + max_seq_len (int): The maximum sequence length supported by the model. + pad_tok_id (int): The special token used for padding batches. + num_fewshot (int): The number of complete fewshot examples to prepend before each test example. These are not identical across examples. + prompt_string (str, default = ''): Prompt string to put once before all fewshot examples/test examples (e.g. 'Translate english to french.'). + example_delimiter (str, default = '\\n'): Separator inserted before (context, answer) pairs (e.g. '\\n') for fewshot sampling and prompting. + continuation_delimiter: (str, default = ' '): Separator inserted between context and answer in each example (e.g. '\\nA: '). + destination_path: (str, default = ''): This is the local file where remote datasets will be saved. + question_prelimiter: (str, default = ''): Text to be prepended before each context, including few shot examples (e.g. "Question: "). + fewshot_random_seed (int, default = 1234): Random seed to use for fewshot sampling + pass_at_k (int): k for how many chances the model gets to write passing code. + generations_per_sample (int): How many outputs to generate per prompt. Passed in generation_kwargs under "num_return_sequences" and overwritten by generation_kwargs dict. + cot_delimiter (str): Delimiter to place between chain of thoughts and continuations. + has_categories: (bool): If ``True``, we will search the dataset file for a category key, and partition the dataset into a separate dataloader for each category occurring in the data. + hf_loading_vars (Dict, default = None): A dictionary containing keyword arguments to be passed into `load_dataset` if dataset is being pulled from HF. + hf_parsing_map (Dict, default = None): A dictionary containing a mapping from HF columns to ICL dataset keys. The dictionary should be formatted {icl_key:[hf_key1, hf_key1]}. + Column contents will be concatenated with ' ' seperating them. If not included, will load the columns already present in the HF dataset. + generation_kwargs (Dict, default = None): A dictionary containing keyword arguments to be passed along to the model's generate function. Overwrites any previously specified generation + keyword args in this fucntion (see https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationConfig + for more details) + early_stopping (List, default = None): A list of strings that, when found in a model's output, will be treated as a stopping criteria at metric computation time. + Used in generation tasks with CoT + do_normalization (bool, default = True): Whether or not to normalize the outputs and labels in InContextLearningGenerationTaskWithAnswersDataset. Only used in generation tasks. + + Returns: + DataLoader: A dataloader used for performing in-context learning evaluation on the dataset provided. + """ + if hf_loading_vars is None: + hf_loading_vars = {} + if hf_parsing_map is None: + hf_parsing_map = {} + if generation_kwargs is None: + generation_kwargs = {} + if early_stopping_criteria is None: + early_stopping_criteria = [] + + if has_categories: + result_dls = {} + output_files = partition_dataset_by_category(dataset_uri, + destination_path, + hf_loading_vars, + hf_parsing_map) + categories = sorted(output_files.keys()) + for category in categories: + partition_uri = output_files[category] + result_dls[category] = build_icl_dataloader( + icl_task_type=icl_task_type, + dataset_uri=partition_uri, + tokenizer=tokenizer, + batch_size=batch_size, + max_seq_len=max_seq_len, + pad_tok_id=pad_tok_id, + num_fewshot=num_fewshot, + prompt_string=prompt_string, + example_delimiter=example_delimiter, + continuation_delimiter=continuation_delimiter, + destination_path=partition_uri + '_tmp', + prelimiter=question_prelimiter, + cot_delimiter=cot_delimiter, + fewshot_random_seed=fewshot_random_seed, + pass_at_k=pass_at_k, + generations_per_sample=generations_per_sample, + hf_loading_vars=hf_loading_vars, + hf_parsing_map=hf_parsing_map, + generation_kwargs=generation_kwargs, + early_stopping_criteria=early_stopping_criteria, + do_normalization=do_normalization, + ) + return result_dls + else: + return build_icl_dataloader( + icl_task_type=icl_task_type, + dataset_uri=dataset_uri, + tokenizer=tokenizer, + batch_size=batch_size, + max_seq_len=max_seq_len, + pad_tok_id=pad_tok_id, + num_fewshot=num_fewshot, + prompt_string=prompt_string, + example_delimiter=example_delimiter, + hf_loading_vars=hf_loading_vars, + hf_parsing_map=hf_parsing_map, + continuation_delimiter=continuation_delimiter, + destination_path=destination_path, + prelimiter=question_prelimiter, + cot_delimiter=cot_delimiter, + fewshot_random_seed=fewshot_random_seed, + pass_at_k=pass_at_k, + generations_per_sample=generations_per_sample, + generation_kwargs=generation_kwargs, + early_stopping_criteria=early_stopping_criteria, + do_normalization=do_normalization, + ) diff --git a/llmfoundry/eval/datasets/utils.py b/llmfoundry/eval/datasets/utils.py new file mode 100644 index 0000000000..7ea7f9fae2 --- /dev/null +++ b/llmfoundry/eval/datasets/utils.py @@ -0,0 +1,277 @@ +# Copyright 2024 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +"""Utility and helper functions for datasets.""" +from __future__ import annotations + +import logging +import random +from typing import Any, Dict, List, Optional, Set + +import torch +import transformers + +__all__ = [ + 'MultiTokenEOSCriteria', +] + +log = logging.getLogger(__name__) + + +def strip_data(example: Dict) -> Dict: + """Remove white space from the begging and end of string values in a. + + dictionary. + + Args: + example: Dictionary to be stripped + + Returns: + dict: The same dictionary with .strip() applied to any value in the dict that is a string + """ + return { + k: v.strip() if isinstance(v, str) else v for k, v in example.items() + } + + +def tokenizer_needs_prefix_space( + tokenizer: transformers.PreTrainedTokenizerBase) -> bool: + """Test for whether a prefix space is needed before the continuation. + + Sentencepiece tokenization should not have a prefix space, but gpt2 style + BPE should. + + Args: + tokenizer: Tokenizer to test + + Returns: + bool: Whether or not the tokenizer needs a prefix space + """ + test_tokens = tokenizer(' a', add_special_tokens=False)['input_ids'] + assert isinstance(test_tokens, list) + return len(test_tokens) == 1 + + +def trim_context(context_enc: List, continuation_enc: List, + max_seq_len: int) -> List: + """Trims a list of tokens down to `max_seq_len` if the length of the list. + + plus the continuation is more than `max_seq_len`. It will always trim tokens + from the left, i.e. tokens at the beginning of the context will be removed. + + Args: + context_enc (list): List of tokens in the context + continuation_enc (lsit): List of tokens in the continuation + max_seq_len (int): Maximum length the model can ingest + + Returns: + list: The encoded context trimmed from the left + """ + if len(continuation_enc) + len(context_enc) > max_seq_len: + context_max_subseq_len = max_seq_len - len(continuation_enc) + + if context_max_subseq_len < 0: + # can't support continuations which are longer than the max seq len + raise Exception( + f'Dataset included continuation longer than the max seq len') + + # clip from the end + context_enc = context_enc[-(context_max_subseq_len):] + return context_enc + + +def get_continuation_span(context_enc: List, + continuation_enc: List) -> torch.Tensor: + """Gets the list of indices of the continuation tokens for language. + + modeling. + + or generation tasks. + + Args: + context_enc (list): List of context tokens + continuation_enc (list): List of continuation tokens + + Returns: + torch.tensor: A tensor containing indices corresponding to continuation tokens + """ + return torch.tensor( + range(len(context_enc), + len(context_enc) + len(continuation_enc))) + + +def make_padded_input(context_enc: List, + continuation_enc: List, + max_seq_len: int, + pad_tok_id: int, + padding_side: str = 'right') -> torch.Tensor: + """Takes an encoded context and continuation and clips the beginning of the. + + context if they're too long. Adds the padding token to the specified side. + + Args: + context_enc (List): The encoded input to the model + continuation_enc (List): The encoded desired output for the example + max_seq_list (int): Maximum length sequences can be + pad_tok_id (int): The token id we pad with + padding_side (str): Which side to pad the context on. Can be 'right' or 'left + + Returns: + input (torch.tensor): The padded and encoded context + continuation_span (torch.tensor): The _inclusive_ range of indices corresponding to the continuation + """ + inp = torch.tensor( + (context_enc + continuation_enc), + dtype=torch.long, + ) + (inp_len,) = inp.shape + + # Sometimes tokenizers that have neither a pad_tok_id or eos_tok_id will pass None in as the padding + # token and cause errors + if not isinstance(pad_tok_id, int): + raise ValueError( + f'`pad_tok_id` must be an integer. Found {type(pad_tok_id)} instead' + ) + # pad length from seq to padding_length + if padding_side == 'right': + inp = torch.cat( + [ + inp, # [seq] + torch.LongTensor((max_seq_len - inp_len) * [pad_tok_id]), + ], + dim=0, + ) + elif padding_side == 'left': + inp = torch.cat( + [ + torch.LongTensor((max_seq_len - inp_len) * [pad_tok_id]), + inp, # [seq] + ], + dim=0, + ) + else: + raise ValueError( + f"Unknown padding_side {padding_side}. padding_side must be either 'left' or 'right'" + ) + + return inp + + +def convert_tokens_to_tensors(batch: Dict, + tokenize_labels: bool) -> Dict[str, Any]: + """HF Datasets converts tensors into lists when we store them, and we don't. + + want to use `type='torch'` because some content in the dataset, like + generation args or single ints, should not be converted. + + Here, we convert those lists of tokens back into tensors in order to feed them into the model. + + Args: + batch (dict): A dictionary of batched inputs + tokenize_labels (bool): Whether or not the labels are tokenized (and need to be stacked) + + Returns: + dict: The batch with torch tensors in the corresponding keys instead of lists of lists + """ + batch['input_ids'] = torch.stack(list(map(torch.tensor, + batch['input_ids']))) + if tokenize_labels: + batch['labels'] = torch.stack(list(map(torch.tensor, batch['labels']))) + batch['continuation_indices'] = list( + map(torch.tensor, batch['continuation_indices'])) + return batch + + +def get_fewshot_sample_idxs(dataset_size: int, num_fewshot: int, + example_idx: int, rng: random.Random) -> Set[int]: + """Samples indices without replacement. If num_fewshot exceeds the number. + + of unique examples in the dataset, then we will have fewer than num_fewshot examples in context. + + Args: + dataset_size (int): Length of the dataset + num_fewshot (int): Number of examples to prepend + example_idx (int): Current example's index (excluded from fewshot choices) + rng (random.Random): RNG for repeatable sample selection + + Returns: + list: Indices of the examples chosen for fewshot selection + """ + num_fewshot = min(dataset_size - 1, num_fewshot) + fewshot_idxs = set(rng.sample(range(0, dataset_size), num_fewshot)) + + if example_idx in fewshot_idxs: + fewshot_idxs.remove(example_idx) + if len(fewshot_idxs) >= dataset_size - 1: + return fewshot_idxs + + replacement_sample = rng.choice(range(0, dataset_size)) + while replacement_sample in fewshot_idxs or replacement_sample == example_idx: + replacement_sample = rng.choice(range(0, dataset_size)) + fewshot_idxs.add(replacement_sample) + return fewshot_idxs + + +class MultiTokenEOSCriteria(transformers.StoppingCriteria): + """Criteria to stop on the specified multi-token sequence. + + Slightly modified from: https://github.com/EleutherAI/lm-evaluation-harness/blob/78545d42f2ca95c6fe0ed220d456eeb94f4485e9/lm_eval/utils.py#L614-L649 + """ + + def __init__( + self, + stop_sequence: str, + tokenizer: transformers.PreTrainedTokenizerBase, + batch_size: int, + ) -> None: + self.done_tracker = [False] * batch_size + self.stop_sequence = stop_sequence + self.stop_sequence_ids = tokenizer.encode(stop_sequence, + add_special_tokens=False) + + # sentence piece tokenizers add a superflous underline token before string-initial \n + # that throws off our calculation of the stop sequence length + # so we remove any token ids that produce empty strings + self.stop_sequence_ids = [ + id for id in self.stop_sequence_ids if tokenizer.decode(id) != '' + ] + + # we look back for 1 more token than it takes to encode our stop sequence + # because tokenizers suck, and a model might generate `['\n', '\n']` but our `sequence` is `['\n\n']` + # and we don't want to mistakenly not stop a generation because our + # (string) stop sequence was output in a different tokenization + + self.stop_sequence_id_len = len(self.stop_sequence_ids) + 1 + self.tokenizer = tokenizer + + def __call__(self, + input_ids: torch.LongTensor, + scores: Optional[torch.FloatTensor] = None, + **kwargs: Dict[str, Any]) -> bool: + # For efficiency, we compare the last n tokens where n is the number of tokens in the stop_sequence + lookback_ids_batch = input_ids[:, :][:, -self.stop_sequence_id_len:] + lookback_tokens_batch = self.tokenizer.batch_decode(lookback_ids_batch) + for i, done in enumerate(self.done_tracker): + if i >= len(lookback_tokens_batch): + # The last batch of a dataset may be smaller than `batch_size` + # Automatically set those indices in the done_tracker to True + # since those indices don't show up in the current batch + self.done_tracker[i] = True + break + elif not done: + self.done_tracker[ + i] = self.stop_sequence in lookback_tokens_batch[i] + return False not in self.done_tracker + + +def stop_sequences_criteria( + tokenizer: transformers.PreTrainedTokenizerBase, + stop_sequences: List[str], + batch_size: int, +) -> transformers.StoppingCriteriaList: + return transformers.StoppingCriteriaList([ + *[ + MultiTokenEOSCriteria(sequence, tokenizer, batch_size) + for sequence in stop_sequences + ], + ]) diff --git a/llmfoundry/eval/metrics/__init__.py b/llmfoundry/eval/metrics/__init__.py new file mode 100644 index 0000000000..6e70e2ece3 --- /dev/null +++ b/llmfoundry/eval/metrics/__init__.py @@ -0,0 +1,21 @@ +# Copyright 2024 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +"""A collection of common torchmetrics.""" + +from llmfoundry.eval.metrics.nlp import ( + InContextLearningCodeEvalAccuracy, + InContextLearningGenerationExactMatchAccuracy, InContextLearningLMAccuracy, + InContextLearningLMExpectedCalibrationError, + InContextLearningMCExpectedCalibrationError, InContextLearningMetric, + InContextLearningMultipleChoiceAccuracy) + +__all__ = [ + 'InContextLearningLMAccuracy', + 'InContextLearningMultipleChoiceAccuracy', + 'InContextLearningGenerationExactMatchAccuracy', + 'InContextLearningMCExpectedCalibrationError', + 'InContextLearningLMExpectedCalibrationError', + 'InContextLearningMetric', + 'InContextLearningCodeEvalAccuracy', +] diff --git a/llmfoundry/eval/metrics/nlp.py b/llmfoundry/eval/metrics/nlp.py new file mode 100644 index 0000000000..55922e28d2 --- /dev/null +++ b/llmfoundry/eval/metrics/nlp.py @@ -0,0 +1,730 @@ +# Copyright 2024 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +"""A collection of common torchmetrics for NLP tasks.""" + +import copy +import functools +import logging +import os +import re +import string +import warnings +from typing import Any, Callable, Dict, List + +import numpy as np +import torch +from composer.utils import dist +from composer.utils.eval_client import (EvalClient, LambdaEvalClient, + LocalEvalClient, + MosaicMLLambdaEvalClient) +from torch import Tensor +from torch.nn import functional as F +from torchmetrics import Metric + +log = logging.getLogger(__name__) + +__all__ = [ + 'InContextLearningMetric', + 'InContextLearningLMAccuracy', + 'InContextLearningMultipleChoiceAccuracy', + 'InContextLearningGenerationExactMatchAccuracy', + 'InContextLearningCodeEvalAccuracy', + 'InContextLearningLMExpectedCalibrationError', + 'InContextLearningMCExpectedCalibrationError', +] + + +class InContextLearningMetric(Metric): + + def __init__(self, *args, **kwargs): # pyright: ignore + super().__init__(*args, **kwargs) + self.needs_batch = True + + def _wrap_update(self, update: Callable) -> Callable: + """Overwrite default _wrap_update to return result of update(). + + Torch metrics wraps update with following wrapped_func but explicitly + does not return the value. In general, torchmetrics update() does not + return a value, but we want to in order to pass it on to + state.metric_outputs. + """ + + @functools.wraps(update) + def wrapped_func(*args: Any, **kwargs: Any) -> None: + self._computed = None + self._update_count += 1 + with torch.set_grad_enabled(self._enable_grad): + try: + update_result = update(*args, **kwargs) + except RuntimeError as err: + if 'Expected all tensors to be on' in str(err): + raise RuntimeError( + 'Encountered different devices in metric calculation (see stacktrace for details).' + \ + ' This could be due to the metric class not being on the same device as input.' + \ + f' Instead of `metric={self.__class__.__name__}(...)` try to do' + \ + f' `metric={self.__class__.__name__}(...).to(device)` where' + \ + ' device corresponds to the device of the input.', + ) from err + raise err + + if self.compute_on_cpu: + self._move_list_states_to_cpu() + return update_result + + return wrapped_func + + def update( + self, + batch: dict, + outputs: torch.Tensor, + labels: torch.Tensor, + ): + """Abstract interface for computing an in-context learning metrics. + + The `output_logits` argument is deprecated and will be removed in v0.21 while it's functionality will + be moved to `outputs`. + + Args: + batch (dict): Batch must consist minimally of `input_ids` as well as any other structure needed + to compute the metric. + output_logits (torch.Tensor): The model outputs evaluated on the batch `input_ids` + labels (torch.Tensor): The correct outputs. + + Raises: + NotImplementedError: Abstract method must be implemented by subclasses + """ + raise NotImplementedError + + +class InContextLearningGenerationExactMatchAccuracy(InContextLearningMetric): + r"""Computes exact match for in-context learning generation tasks. + + ICL generation tasks consist of some number of prompted generation tasks with correct answers + followed by a test task where the model must correctly produce one of a number of valid answers. + + For example, the model may be provided the context below and evaluated on its ability to correctly predict the continuation. + + Context: `Question: Who was president of the United States in 2012?\nAnswer: Barack Obama\nQuestion: Is water wet?\nAnswer: ` + Answers: [`yes`] + + The model will be expected to correctly produce one of the answers, following some optional normalization. + + Adds metric state variables: + correct (float): The number of instances where the prediction was a prefix for any of the answer aliases. + total (float): The number of total instances that were predicted. + + Args: + dist_sync_on_step (bool, optional): Synchronize metric state across processes at + each forward() before returning the value at the step. Default: ``False``. + """ + + # Make torchmetrics call update only once + full_state_update = False + + def __init__(self, dist_sync_on_step: bool = False): + # state from multiple processes + super().__init__(dist_sync_on_step=dist_sync_on_step) + self.add_state('correct', + default=torch.tensor(0.), + dist_reduce_fx='sum') + self.add_state('total', default=torch.tensor(0.), dist_reduce_fx='sum') + self.metric_result_dict = { + 'cleaned_output': [], + 'original_label': [], + 'cleaned_label': [], + 'result': [], + } + + def normalize_answer(self, answer: str): + """Lower text and remove punctuation, articles and extra whitespace. + + Copied from https://github.com/mandarjoshi90/triviaqa/blob/master/evaluation/triviaqa_evaluation.py + """ + + def remove_articles(text: str) -> str: + return re.sub(r'\b(a|an|the)\b', ' ', text) + + def white_space_fix(text: str) -> str: + return ' '.join(text.split()) + + def handle_punc(text: str) -> str: + exclude = set(string.punctuation + + ''.join([u'‘', u'’', u'´', u'`'])) + return ''.join(ch if ch not in exclude else ' ' for ch in text) + + def lower(text: str) -> str: + return text.lower() + + def replace_underscore(text: str) -> str: + return text.replace('_', ' ') + + return white_space_fix( + remove_articles(handle_punc(lower( + replace_underscore(answer))))).strip() + + def update( + self, + batch: Dict[str, Any], + outputs: List[str], + labels: List[List[str]], + ): + cot_delimiter = batch.get('cot_delimiter', '') + do_normalization = batch.get('do_normalization', True) + stopping_criteria = batch.get('stopping_criteria', None) + metric_result_dict = copy.deepcopy(self.metric_result_dict) + for sample_output, sample_labels in zip(outputs, labels): + final_answer = sample_output + + if stopping_criteria is not None and len(stopping_criteria) > 0: + final_answer = re.split('|'.join(stopping_criteria), + final_answer)[0] + + if cot_delimiter is not None and len(cot_delimiter) > 0: + final_answer = final_answer.split(cot_delimiter)[-1] + + if do_normalization: + cleaned_final_answer = self.normalize_answer(final_answer) + cleaned_sample_labels = { + self.normalize_answer(label) for label in sample_labels + } + else: + # even if normalization is off, we should still strip leading/trailing whitespaces + cleaned_final_answer = final_answer.strip() + cleaned_sample_labels = { + sample_label.strip() for sample_label in sample_labels + } + metric_result_dict['original_label'].append(sample_labels) + metric_result_dict['cleaned_output'].append(cleaned_final_answer) + metric_result_dict['cleaned_label'].append(cleaned_sample_labels) + + if any( + cleaned_final_answer.startswith(label) + for label in cleaned_sample_labels): + self.correct += torch.tensor(1.0) + metric_result_dict['result'].append(1) + else: + metric_result_dict['result'].append(0) + + self.total += torch.tensor(1.0) + + return metric_result_dict + + def compute(self): + assert isinstance(self.correct, Tensor) + assert isinstance(self.total, Tensor) + return self.correct / self.total + + +class InContextLearningLMAccuracy(InContextLearningMetric): + r"""Computes accuracy for In-context learning language modeling tasks. + + ICL LM tasks consist of some number of example language modeling tasks (referred to as the 'context'), followed by a test task where the model must correctly predict all the tokens + following tokens in some passage (referred to as the 'continuation'). + + For example, the model may be provided the context below and evaluated on its ability to correctly predict the continuation. Note: it doesn't matter + whether the model correctly predicts the context tokens. + + Context: `The dog is->fuzzy\nthe water is->hot\nthe tree is->` + Continuation: `green` + + Adds metric state variables: + correct (float): The number of instances where the prediction masked the target. + total (float): The number of total instances that were predicted. + + Args: + dist_sync_on_step (bool, optional): Synchronize metric state across processes at + each forward() before returning the value at the step. Default: ``False``. + """ + + # Make torchmetrics call update only once + full_state_update = False + + def __init__(self, dist_sync_on_step: bool = False): + # state from multiple processes + super().__init__(dist_sync_on_step=dist_sync_on_step) + self.add_state('correct', + default=torch.tensor(0.), + dist_reduce_fx='sum') + self.add_state('total', default=torch.tensor(0.), dist_reduce_fx='sum') + self.metric_result_dict = { + 'context': [], + 'label': [], + 'output': [], + 'result': [] + } + + def update(self, batch: dict, outputs: torch.Tensor, labels: torch.Tensor): + + metric_result_dict = copy.deepcopy(self.metric_result_dict) + for batch_idx, cont_idx in enumerate(batch['continuation_indices']): + cont_tok_pred = outputs[batch_idx].index_select(dim=0, + index=cont_idx - + 1).argmax(dim=-1) + cont_tok_targ = labels[batch_idx].index_select(dim=0, + index=cont_idx - 1) + + metric_result_dict['context'].append( + batch['input_ids'][batch_idx][:cont_idx[0]]) + metric_result_dict['label'].append(cont_tok_targ) + metric_result_dict['output'].append(cont_tok_pred) + + correct = (cont_tok_pred == cont_tok_targ).all().int() + self.correct += correct + metric_result_dict['result'].append(int(correct)) + + self.total += torch.tensor(1.0) + + return metric_result_dict + + def compute(self): + assert isinstance(self.correct, Tensor) + assert isinstance(self.total, Tensor) + return self.correct / self.total + + +class InContextLearningMultipleChoiceAccuracy(InContextLearningMetric): + r"""Computes accuracy for In-context learning multiple choice tasks. + + ICL MC tasks consists of a series of questions with some number of possible choices (only one of which can be correct). + At inference time each possible choice is given to the model as a separate input and the one for which the model assigns + the lowest perplexity to the choice is considered the model's choice. The model is correct if it "chooses" the right answer. + + Context: `The dog is->fuzzy\nthe water is->hot\nthe tree is->` + Continuation: `green` + + Adds metric state variables: + correct (float): The number of instances where the prediction masked the target. + total (float): The number of total instances that were predicted. + + Args: + dist_sync_on_step (bool, optional): Synchronize metric state across processes at + each forward() before returning the value at the step. Default: ``False``. + """ + + # Make torchmetrics call update only once + full_state_update = False + + def __init__(self, dist_sync_on_step: bool = False): + # state from multiple processes + super().__init__(dist_sync_on_step=dist_sync_on_step) + self.add_state('correct', + default=torch.tensor(0.0), + dist_reduce_fx='sum') + self.add_state('total', default=torch.tensor(0.0), dist_reduce_fx='sum') + self.metric_result_dict = { + 'context': [], + 'correct_choice': [], + 'correct_choice_idx': [], + 'selected_choice': [], + 'selected_choice_idx': [], + 'all_choices': [], + 'result': [], + } + + def update(self, batch: dict, outputs: torch.Tensor, labels: torch.Tensor): + + perplexities = [] + for batch_idx, cont_idx in enumerate(batch['continuation_indices']): + # continuation indices refer to indices in the original input's token space + cont_tok_logits = outputs[batch_idx].index_select(dim=0, + index=cont_idx - + 1) + # labels have been shifted left by one index, so the cont_idx needs to be shifted as well. + cont_tok_targ = labels[batch_idx].index_select(dim=0, + index=cont_idx - 1) + cross_entropy = F.cross_entropy(cont_tok_logits, cont_tok_targ) + perplexity = torch.exp(cross_entropy) + perplexities.append(perplexity) + + metric_result_dict = copy.deepcopy(self.metric_result_dict) + for (start, end), gold_idx in zip(batch['choice_groupings'], + batch['gold_indices']): + subset = perplexities[start:end] + idx_min = subset.index(min(subset)) + + if idx_min == gold_idx: + self.correct += torch.tensor(1.0) + metric_result_dict['result'].append(1) + else: + metric_result_dict['result'].append(0) + + question = batch['input_ids'][ + start][:batch['continuation_indices'][start][0]] + + correct_choice = batch['input_ids'][start:end][gold_idx][ + batch['continuation_indices'][start:end][gold_idx][0]: + batch['continuation_indices'][start:end][gold_idx][-1] + 1] + selected_choice = batch['input_ids'][start:end][idx_min][ + batch['continuation_indices'][start:end][idx_min][0]: + batch['continuation_indices'][start:end][idx_min][-1] + 1] + metric_result_dict['context'].append(question) + metric_result_dict['correct_choice'].append(correct_choice) + metric_result_dict['correct_choice_idx'].append(gold_idx) + metric_result_dict['selected_choice'].append(selected_choice) + metric_result_dict['selected_choice_idx'].append(idx_min) + all_choices = batch['input_ids'][start:end] + # Unpads the choices. Necessary in case different choices have different token lengths. + if 'attention_mask' in batch: + all_choices_list = [ + choice[batch['attention_mask'][i]] + for i, choice in enumerate(all_choices) + ] + metric_result_dict['all_choices'].append(all_choices_list) + + self.total += torch.tensor(1.0) + + # Don't return all_choices if we didn't fill it up (i.e. didn't use causal lms) + if metric_result_dict['all_choices'] == []: + metric_result_dict.pop('all_choices') + + return metric_result_dict + + def compute(self): + assert isinstance(self.correct, Tensor) + assert isinstance(self.total, Tensor) + return self.correct.float() / self.total + + +class InContextLearningCodeEvalAccuracy(InContextLearningMetric): + r"""Computes accuracy for In-context learning (ICL) code evaluation tasks. + + ICL code eval tasks consist of some number of example code eval tasks (referred to as the 'context'), followed by a test task where the model must + complete the code, where we term the code completion a 'continuation'. + + In each case, the model constructs a given number of continuations (termed pass@K for K continuations), and each continuation is run against a set of test cases. The model is considered + correct if at least one of the proposed continuations passes all the test cases. + + Runs on AWS Lambdas by default. + + Adds metric state variables: + correct (float): The number of instances where the predictions passed all the test cases. + total (float): The number of total instances that were predicted. + + Args: + dist_sync_on_step (bool, optional): Synchronize metric state across processes at + each forward() before returning the value at the step. Default: ``False``. + """ + + # Make torchmetrics call update only once + full_state_update = False + + def __init__(self, dist_sync_on_step: bool = False): + # state from multiple processes + super().__init__(dist_sync_on_step=dist_sync_on_step) + + self._initialized = False + self.dataset_size = 0 + self.pass_at_k = [] + self.num_generations = 0 + self.eval_device = os.environ.get('CODE_EVAL_DEVICE', None) + if self.eval_device is not None: + self.eval_device = self.eval_device.upper() + self.metric_result_dict = { + 'context': [], + 'output': [], + 'result': [], + 'sample_id': [] + } + + def get_client(self) -> EvalClient: + """Returns a client for the appropriate remote platform.""" + client = None + if self.eval_device == 'LOCAL': + warnings.warn( + 'Running code eval locally may be insecure. Please set environment variable CODE_EVAL_DEVICE ' + + + 'to LAMBDA to run on remote. To use Lambdas, spin up your instance that checks code, set the URL as ' + + 'CODE_EVAL_URL and the API key as CODE_EVAL_APIKEY.') + log.debug('Running code eval locally.') + client = LocalEvalClient() + elif self.eval_device == 'LAMBDA': + client = LambdaEvalClient() + elif self.eval_device == 'MOSAICML': + client = MosaicMLLambdaEvalClient() + elif self.eval_device is None: + raise ValueError( + 'Attempting to use InContextLearningCodeEvalAccuracy but environment ' + + + 'variable `CODE_EVAL_DEVICE` is not set. Please set it to `CODE_EVAL_DEVICE` ' + + + 'to one of `LOCAL` (for unsafe local eval), `LAMBDA` (for AWS lambda ' + + 'evaluation), or `MOSAICML` (for lambda eval through MAPI).') + else: + raise ValueError( + 'Environment variable `CODE_EVAL_DEVICE` must be one of `LOCAL`, ' + + f'`LAMBDA`, or `MOSAICML` but got {self.eval_device}.') + + return client + + def estimator(self, n: int, c: int, k: int) -> float: + """Computes the pass@k metric. + + Given the number of generated samples, n, the number of correct samples, c, and the k of interest, + this function calculates pass@k as 1 - comb(n - c, k) / comb(n, k) as per the definition of + pass@k in the HumanEval paper (https://arxiv.org/abs/2107.03374) and it's associated implementation: + https://github.com/openai/human-eval. + """ + if n - c < k: + return 1.0 + return 1.0 - float(np.prod(1.0 - k / np.arange(n - c + 1, n + 1))) + + def _initialize_state(self, batch: dict[str, Any]): + device = batch['input_ids'].device + self.dataset_size = batch['dataset_size'] + self.pass_at_k = batch['pass_at_k'] + self.num_generations = batch['generations_per_sample'] + + # We need to defer the accumulator initialization because it depends on dataset size + self.add_state('correct', + default=torch.zeros(self.dataset_size, device=device), + dist_reduce_fx='sum') + self.add_state('total', + default=torch.zeros(self.dataset_size, device=device), + dist_reduce_fx='sum') + dist.barrier() + self._initialized = True + + def update(self, batch: Dict[str, Any], outputs: List[str], + labels: List[str]): + """Updates the pass@k accuracy of code generation. + + Given a batch of prompts, test cases, and code generations, evaluates the code generations + against the test cases and augments the pass@k accuracy of the batch to the values so far. + + Args: + batch (Dict[str, Any]): A batch of data produced by the InContextLearningCodeEvalDataset, with + the prompt, test cases, and entry points. This will be a dictionary that must have the following + arguments: + { + 'prompts': List[str], + 'test_inputs': List[List[str]], + 'test_outputs': List[List[str]], + 'entry_points': List[str], + 'languages': List[str], + 'generation_kwargs': Dict[str, Any] + } + outputs (List[str]): A list of code generations in the format of HF generate with beam search, + which is the a list of strings in groups of beam_size e.g. for beam size 2 and batch size 2, the list + will be of the format [prompt 1 gen 1, prompt 1 gen 2, prompt 2 gen 1, prompt 2 gen 2] + labels (List[str]): A list of the correct code generations, for compatibility with existing HF generate + functionalities. This is not used. + """ + if not self._initialized: + self._initialize_state(batch) + + del labels # never used + client = self.get_client() + + metric_result_dict = copy.deepcopy(self.metric_result_dict) + for sample_id, code_gen, sample_prompt, test_inputs, test_outputs, entry_point, language in zip( + batch['sample_id'], outputs, batch['prompts'], + batch['test_inputs'], batch['test_outputs'], + batch['entry_points'], batch['languages']): + + idx = sample_id + self.total[idx] += 1.0 + metric_result_dict['sample_id'].append(sample_id) + + code_gen = re.split( + r'\n[A-Za-z0-9#`]', + code_gen)[0] # remove everything after function ends + final_code = sample_prompt + code_gen # combine prompt with the code generation + metric_result_dict['context'].append(sample_prompt) + metric_result_dict['output'].append(code_gen) + + test_results = [] + for test_input, test_output in zip(test_inputs, test_outputs): + payload = { + 'code': final_code, + 'input': test_input, + 'output': test_output, + 'entry_point': entry_point, + 'language': language, + } + + result = client.invoke([[[payload]]])[0][0][0] + test_results.append(result) + + if all(test_results): + self.correct[idx] += 1.0 + metric_result_dict['result'].append(1) + else: + metric_result_dict['result'].append(0) + + client.close() # pyright: ignore [reportOptionalMemberAccess] + return metric_result_dict + + def compute(self): + assert isinstance(self.correct, Tensor) + assert isinstance(self.total, Tensor) + complete = self.total == self.num_generations # so that eval subset batches can be used + + if complete.sum() < (self.total != 0).sum(): + warnings.warn( + 'Some samples in the dataset have less than the expected number of generations. ' + + + 'This is expected if you are using a subset of the dataset for evaluation.' + ) + + if (self.correct > self.total).any().item(): + raise ValueError( + 'Internal error some samples have more correct than total generations. This should not happen.' + ) + + results = {} + n = self.num_generations + + for k in self.pass_at_k: + pass_at_k = sum([ + self.estimator(n, int(c.item()), k) + for c in self.correct[complete] + ]) / complete.sum().item() + results[f'pass@{k}'] = torch.tensor(pass_at_k) + + if len(results) == 1: # backwards compatibility + return list(results.values())[0] + + return results + + +class InContextLearningExpectedCalibrationError(InContextLearningMetric): + """Generic class for Expected Calibration Error (ECE). + + Citation: https://arxiv.org/pdf/1706.04599.pdf + + Expected calibration error is calculated by dividing predictions into buckets based on the model's confidence (a probability value between 0 and 1). + We then calculate the accuracy within each bucket and calculate the average gap between confidence and accuracy + across buckets, weighted by the number of samples in each bucket. + + Each task must implement its own definition of "confidence" to be computed via the `update` method. + + Adds metric state variables: + bucket_totals (float): The number of instances where the prediction masked the target per bucket. + bucket_correct (float): The number of total instances that were predicted per bucket. + + Args: + dist_sync_on_step (bool, optional): Synchronize metric state across processes at + each forward() before returning the value at the step. Default: ``False``. + n_buckets (int): Number of distinct buckets to split the confidence distribution into + """ + + def __init__(self, dist_sync_on_step: bool = False, n_buckets: int = 10): + # state from multiple processes + super().__init__(dist_sync_on_step=dist_sync_on_step) + self.n_buckets = n_buckets + if n_buckets < 1: + raise Exception('`n_buckets`') + self.add_state('bucket_totals', + default=torch.zeros(n_buckets), + dist_reduce_fx='sum') + self.add_state('bucket_correct', + default=torch.zeros(n_buckets), + dist_reduce_fx='sum') + + def update(self, batch: dict, outputs: torch.Tensor, labels: torch.Tensor): + pass + + def compute(self): + assert isinstance(self.bucket_correct, Tensor) + assert isinstance(self.bucket_totals, Tensor) + + result = torch.tensor(0.0, device=self.bucket_correct.device) + total_obs = torch.sum(self.bucket_totals) + for i in range(self.n_buckets): + if self.bucket_totals[i] == 0: + continue + + acc_bucket_i = self.bucket_correct[i] / self.bucket_totals[i] + upper_bound = (i + 1) / self.n_buckets + lower_bound = i / self.n_buckets + conf_bucket_i = torch.tensor((upper_bound + lower_bound) / 2, + device=self.bucket_correct.device) + result += (self.bucket_totals[i] / + total_obs) * torch.abs(acc_bucket_i - conf_bucket_i) + return result + + +class InContextLearningMCExpectedCalibrationError( + InContextLearningExpectedCalibrationError): + r"""Computes Expected Calibration Error (ECE) for In-context learning (ICL) + + multiple choice (MC) tasks. (source: https://arxiv.org/abs/2012.00955). + + For MC tasks, the model confidence is defined as the softmax of average per-token probability assigned to the top question choice. + + See `InContextLearningExpectedCalibrationError` for more info. + """ + + # Make torchmetrics call update only once + full_state_update = False + + def update(self, batch: dict, outputs: torch.Tensor, labels: torch.Tensor): + + outputs = torch.softmax(outputs, dim=2) + probabilites = [] + for batch_idx, cont_idx in enumerate(batch['continuation_indices']): + cont_tok_logits = outputs[batch_idx].index_select(dim=0, + index=cont_idx - + 1) + cont_tok_targ = labels[batch_idx].index_select(dim=0, + index=cont_idx - 1) + probability = cont_tok_logits.index_select( + dim=1, index=cont_tok_targ).diagonal().mean() + probabilites.append(probability) + + for (start, end), gold_idx in zip(batch['choice_groupings'], + batch['gold_indices']): + subset = probabilites[start:end] + idx_max = subset.index(max(subset)) + confidence = torch.tensor(subset).max() / torch.tensor(subset).sum() + + assert confidence >= 0.0 and confidence <= 1.0 + bucket_idx = int(confidence * self.n_buckets) + if bucket_idx == self.n_buckets: + bucket_idx -= 1 + + if idx_max == gold_idx: + self.bucket_correct[ + bucket_idx] += 1 # pyright: ignore [reportGeneralTypeIssues] + + self.bucket_totals[ + bucket_idx] += 1 # pyright: ignore [reportGeneralTypeIssues] + + +class InContextLearningLMExpectedCalibrationError( + InContextLearningExpectedCalibrationError): + r"""Computes Expected Calibration Error (ECE) for In-context learning (ICL) + + language modeling (LM) tasks. (cite: https://arxiv.org/pdf/1706.04599.pdf). + + For LM tasks, the model confidence is defined as the minimum probability assigned to all tokens in the continuation. + + See `InContextLearningExpectedCalibrationError` for more info. + """ + + # Make torchmetrics call update only once + full_state_update = False + + def update(self, batch: dict, outputs: torch.Tensor, labels: torch.Tensor): + + outputs = torch.softmax(outputs, dim=2) + for batch_idx, cont_idx in enumerate(batch['continuation_indices']): + cont_tok_logits = outputs[batch_idx].index_select(dim=0, + index=cont_idx - + 1) + cont_tok_pred = cont_tok_logits.argmax(dim=-1) + confidence = cont_tok_logits.max(dim=-1).values.min() + cont_tok_targ = labels[batch_idx].index_select(dim=0, + index=cont_idx - 1) + assert confidence >= 0.0 and confidence <= 1.0 + bucket_idx = int(confidence * self.n_buckets) + if bucket_idx == self.n_buckets: + bucket_idx -= 1 + + if (cont_tok_pred == cont_tok_targ).all(): + self.bucket_correct[ + bucket_idx] += 1 # pyright: ignore [reportGeneralTypeIssues] + + self.bucket_totals[ + bucket_idx] += 1 # pyright: ignore [reportGeneralTypeIssues] diff --git a/llmfoundry/layers_registry.py b/llmfoundry/layers_registry.py index 9c7dabe128..24593144aa 100644 --- a/llmfoundry/layers_registry.py +++ b/llmfoundry/layers_registry.py @@ -1,20 +1,106 @@ # Copyright 2024 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 -from typing import Type +from typing import Callable, Type import torch from llmfoundry.utils.registry_utils import create_registry -# Layers -_norm_description = """The norms registry is used to register classes that implement normalization layers.""" +_norm_description = ( + 'The norms registry is used to register classes that implement normalization layers.' +) norms = create_registry('llmfoundry', 'norms', generic_type=Type[torch.nn.Module], entry_points=True, description=_norm_description) +_fc_description = ( + 'The fully connected layers registry is used to register classes that implement fully connected layers (i.e. torch.nn.Linear).' + + + 'These classes should take in_features and out_features in as args, at a minimum.' +) +fcs = create_registry('llmfoundry', + 'fcs', + generic_type=Type[torch.nn.Module], + entry_points=True, + description=_fc_description) + +_ffns_description = ( + 'The ffns registry is used to register functions that build ffn layers.' + + 'See ffn.py for examples.') +ffns = create_registry('llmfoundry', + 'ffns', + generic_type=Callable, + entry_points=True, + description=_ffns_description) + +_ffns_with_norm_description = ( + 'The ffns_with_norm registry is used to register functions that build ffn layers that apply a normalization layer.' + + 'See ffn.py for examples.') +ffns_with_norm = create_registry('llmfoundry', + 'ffns_with_norm', + generic_type=Callable, + entry_points=True, + description=_ffns_with_norm_description) + +_ffns_with_megablocks_description = ( + 'The ffns_with_megablocks registry is used to register functions that build ffn layers using MegaBlocks.' + + 'See ffn.py for examples.') +ffns_with_megablocks = create_registry( + 'llmfoundry', + 'ffns_with_megablocks', + generic_type=Callable, + entry_points=True, + description=_ffns_with_megablocks_description) + +_attention_classes_description = ( + 'The attention_classes registry is used to register classes that implement attention layers. See ' + + 'attention.py for expected constructor signature.') +attention_classes = create_registry('llmfoundry', + 'attention_classes', + generic_type=Type[torch.nn.Module], + entry_points=True, + description=_attention_classes_description) + +_attention_implementations_description = ( + 'The attention_implementations registry is used to register functions that implement the attention operation.' + + 'See attention.py for expected function signature.') +attention_implementations = create_registry( + 'llmfoundry', + 'attention_implementations', + generic_type=Callable, + entry_points=True, + description=_attention_implementations_description) + +_param_init_fns_description = ( + 'The param_init_fns registry is used to register functions that initialize parameters.' + + + 'These will be called on a module to initialize its parameters. See param_init_fns.py for examples.' +) +param_init_fns = create_registry('llmfoundry', + 'param_init_fns', + generic_type=Callable[..., None], + entry_points=True, + description=_param_init_fns_description) + +_module_init_fns_description = """The module_init_fns registry is used to register functions that initialize specific modules. +These functions should return True if they initialize the module, and False otherwise. This allows them to be called without knowing their contents. +They should take in the module, init_div_is_residual, and div_is_residual arguments.""" +module_init_fns = create_registry('llmfoundry', + 'module_init_fns', + generic_type=Callable[..., bool], + entry_points=True, + description=_module_init_fns_description) __all__ = [ 'norms', + 'param_init_fns', + 'module_init_fns', + 'ffns', + 'ffns_with_norm', + 'ffns_with_megablocks', + 'attention_classes', + 'attention_implementations', + 'fcs', ] diff --git a/llmfoundry/metrics/__init__.py b/llmfoundry/metrics/__init__.py index 6c71a3ea08..e8310687a1 100644 --- a/llmfoundry/metrics/__init__.py +++ b/llmfoundry/metrics/__init__.py @@ -1,14 +1,15 @@ # Copyright 2024 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 -from composer.metrics import (InContextLearningCodeEvalAccuracy, - InContextLearningLMAccuracy, - InContextLearningLMExpectedCalibrationError, - InContextLearningMCExpectedCalibrationError, - InContextLearningMultipleChoiceAccuracy, - InContextLearningQAAccuracy, MaskedAccuracy) -from composer.metrics.nlp import LanguageCrossEntropy, LanguagePerplexity +from composer.metrics import (LanguageCrossEntropy, LanguagePerplexity, + MaskedAccuracy) +from llmfoundry.eval.metrics import ( + InContextLearningCodeEvalAccuracy, + InContextLearningGenerationExactMatchAccuracy, InContextLearningLMAccuracy, + InContextLearningLMExpectedCalibrationError, + InContextLearningMCExpectedCalibrationError, + InContextLearningMultipleChoiceAccuracy) from llmfoundry.metrics.token_acc import TokenAccuracy from llmfoundry.registry import metrics @@ -19,7 +20,8 @@ metrics.register('mc_expected_calibration_error', func=InContextLearningMCExpectedCalibrationError) metrics.register('mc_accuracy', func=InContextLearningMultipleChoiceAccuracy) -metrics.register('qa_accuracy', func=InContextLearningQAAccuracy) +metrics.register('qa_accuracy', + func=InContextLearningGenerationExactMatchAccuracy) metrics.register('code_eval_accuracy', func=InContextLearningCodeEvalAccuracy) metrics.register('language_cross_entropy', func=LanguageCrossEntropy) metrics.register('language_perplexity', func=LanguagePerplexity) @@ -54,7 +56,7 @@ 'InContextLearningLMExpectedCalibrationError', 'InContextLearningMCExpectedCalibrationError', 'InContextLearningMultipleChoiceAccuracy', - 'InContextLearningQAAccuracy', + 'InContextLearningGenerationExactMatchAccuracy', 'InContextLearningCodeEvalAccuracy', 'LanguageCrossEntropy', 'LanguagePerplexity', diff --git a/llmfoundry/models/inference_api_wrapper/interface.py b/llmfoundry/models/inference_api_wrapper/interface.py index 4c30e7822d..91f6fb2600 100644 --- a/llmfoundry/models/inference_api_wrapper/interface.py +++ b/llmfoundry/models/inference_api_wrapper/interface.py @@ -5,12 +5,12 @@ import torch from composer.core.types import Batch -from composer.metrics import InContextLearningMetric from composer.models import ComposerModel from omegaconf import DictConfig from torchmetrics import Metric from transformers import AutoTokenizer +from llmfoundry.eval.metrics import InContextLearningMetric from llmfoundry.metrics import DEFAULT_CAUSAL_LM_EVAL_METRICS diff --git a/llmfoundry/models/layers/__init__.py b/llmfoundry/models/layers/__init__.py index 262f190b47..dca55098c4 100644 --- a/llmfoundry/models/layers/__init__.py +++ b/llmfoundry/models/layers/__init__.py @@ -2,13 +2,13 @@ # SPDX-License-Identifier: Apache-2.0 from llmfoundry.models.layers.attention import ( - ATTN_CLASS_REGISTRY, GroupedQueryAttention, MultiheadAttention, - MultiQueryAttention, attn_bias_shape, build_alibi_bias, build_attn_bias, - flash_attn_fn, scaled_multihead_dot_product_attention) + GroupedQueryAttention, MultiheadAttention, MultiQueryAttention, + attn_bias_shape, build_alibi_bias, build_attn_bias, flash_attn_fn, + scaled_multihead_dot_product_attention) from llmfoundry.models.layers.blocks import MPTBlock from llmfoundry.models.layers.custom_embedding import SharedEmbedding -from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY -from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY, MPTMLP, build_ffn +from llmfoundry.models.layers.fc import * +from llmfoundry.models.layers.ffn import MPTMLP from llmfoundry.models.layers.norm import LPLayerNorm __all__ = [ @@ -20,12 +20,8 @@ 'attn_bias_shape', 'build_attn_bias', 'build_alibi_bias', - 'ATTN_CLASS_REGISTRY', 'MPTMLP', 'MPTBlock', 'LPLayerNorm', - 'FC_CLASS_REGISTRY', 'SharedEmbedding', - 'FFN_CLASS_REGISTRY', - 'build_ffn', ] diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py index c24b3d4afa..6614d5d161 100644 --- a/llmfoundry/models/layers/attention.py +++ b/llmfoundry/models/layers/attention.py @@ -14,8 +14,9 @@ from packaging import version from torch import nn -from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY -from llmfoundry.models.layers.layer_builders import build_norm +from llmfoundry.layers_registry import (attention_classes, + attention_implementations) +from llmfoundry.models.layers.layer_builders import build_fc, build_norm def is_flash_v2_installed(v2_version: str = '2.0.0'): @@ -341,6 +342,7 @@ def flash_attn_fn( return output, None, past_key_value +@attention_classes.register_class('grouped_query_attention') class GroupedQueryAttention(nn.Module): """Grouped Query Attention (GQA) is a generalization of Multi-head (MHA). @@ -406,10 +408,11 @@ def __init__( 'bias': bias, } fc_kwargs['device'] = device - self.Wqkv = FC_CLASS_REGISTRY[fc_type]( - self.d_model, - self.d_model + 2 * self.kv_n_heads * self.head_dim, - **fc_kwargs, + self.Wqkv = build_fc( + name=fc_type, + in_features=self.d_model, + out_features=self.d_model + 2 * self.kv_n_heads * self.head_dim, + fc_kwargs=fc_kwargs, ) # for param init fn; enables shape based init of fused layers fuse_splits = [ @@ -433,17 +436,13 @@ def __init__( device=device, ) - if self.attn_impl == 'flash': - self.attn_fn = flash_attn_fn - elif self.attn_impl == 'torch': - self.attn_fn = scaled_multihead_dot_product_attention - else: - raise ValueError(f'{attn_impl=} is an invalid setting.') + self.attn_fn = attention_implementations.get(self.attn_impl) - self.out_proj = FC_CLASS_REGISTRY[fc_type]( - self.d_model, - self.d_model, - **fc_kwargs, + self.out_proj = build_fc( + name=fc_type, + in_features=self.d_model, + out_features=self.d_model, + fc_kwargs=fc_kwargs, ) self.out_proj._is_residual = True @@ -572,6 +571,7 @@ def forward( return self.out_proj(context), attn_weights, past_key_value +@attention_classes.register_class('multihead_attention') class MultiheadAttention(GroupedQueryAttention): """Multi-head self attention. @@ -612,6 +612,7 @@ def __init__( ) +@attention_classes.register_class('multiquery_attention') class MultiQueryAttention(GroupedQueryAttention): """Multi-Query self attention. @@ -740,8 +741,6 @@ def build_alibi_bias( return alibi_bias.to(dtype=dtype) -ATTN_CLASS_REGISTRY = { - 'multihead_attention': MultiheadAttention, - 'multiquery_attention': MultiQueryAttention, - 'grouped_query_attention': GroupedQueryAttention -} +attention_implementations.register('flash', func=flash_attn_fn) +attention_implementations.register('torch', + func=scaled_multihead_dot_product_attention) diff --git a/llmfoundry/models/layers/blocks.py b/llmfoundry/models/layers/blocks.py index 18b9f979f4..40f349368f 100644 --- a/llmfoundry/models/layers/blocks.py +++ b/llmfoundry/models/layers/blocks.py @@ -8,9 +8,9 @@ import torch import torch.nn as nn -from llmfoundry.models.layers.attention import ATTN_CLASS_REGISTRY -from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY, build_ffn -from llmfoundry.models.layers.layer_builders import build_norm +from llmfoundry.layers_registry import ffns_with_norm +from llmfoundry.models.layers.layer_builders import (build_attention_layer, + build_ffn, build_norm) try: from flash_attn.bert_padding import unpad_input, pad_input # type: ignore # yapf: disable # isort: skip @@ -73,12 +73,15 @@ def __init__( del kwargs # unused, just to capture any extra args from the config super().__init__() + ffn_type = ffn_config['ffn_type'] + ffn_has_norm = ffn_type in ffns_with_norm + if self.fuse_norm_attn_norm: self.norm_attn_norm = FusedNormAttentionNorm( d_model=d_model, n_heads=n_heads, attn_config=attn_config, - ffn_config=ffn_config, + ffn_has_norm=ffn_has_norm, fc_type=fc_type, resid_pdrop=resid_pdrop, norm_type=norm_type, @@ -87,8 +90,6 @@ def __init__( ) else: assert isinstance(attn_config['attn_type'], str) - attn_class = ATTN_CLASS_REGISTRY[attn_config['attn_type']] - # Necessary to avoid passing extraneous args into attn_class while allowing the use of **kwargs args_to_exclude_in_attn_class = { 'attn_type', 'alibi', 'attn_uses_sequence_id', 'alibi_bias_max', @@ -106,17 +107,19 @@ def __init__( normalized_shape=d_model, device=device, ) - self.attn = attn_class( - d_model=d_model, - n_heads=n_heads, - fc_type=fc_type, - device=device, - **attn_config_subset_for_attn_class, - bias=not no_bias, + self.attn = build_attention_layer( + name=attn_config['attn_type'], + attn_kwargs={ + 'd_model': d_model, + 'n_heads': n_heads, + 'fc_type': fc_type, + 'device': device, + 'bias': not no_bias, + **attn_config_subset_for_attn_class + }, ) self.norm_2 = None - if not getattr(FFN_CLASS_REGISTRY[ffn_config['ffn_type']], - '_has_norm', False): + if not ffn_has_norm: self.norm_2 = build_norm( name=norm_type.lower(), normalized_shape=d_model, @@ -124,12 +127,14 @@ def __init__( ) self.ffn = build_ffn( + name=ffn_type, d_model=d_model, expansion_ratio=expansion_ratio, device=device, bias=not no_bias, - **ffn_config, + ffn_kwargs=ffn_config, ) + self.resid_attn_dropout = nn.Dropout(resid_pdrop) self.resid_ffn_dropout = nn.Dropout(resid_pdrop) self.use_pad_tok_in_ffn = use_pad_tok_in_ffn @@ -197,7 +202,7 @@ def __init__( d_model: int, n_heads: int, attn_config: Optional[Dict] = None, - ffn_config: Optional[Dict] = None, + ffn_has_norm: bool = False, fc_type: str = 'torch', resid_pdrop: float = 0.0, norm_type: str = 'low_precision_layernorm', @@ -207,9 +212,7 @@ def __init__( ): super().__init__() assert attn_config is not None - assert ffn_config is not None assert isinstance(attn_config['attn_type'], str) - attn_class = ATTN_CLASS_REGISTRY[attn_config['attn_type']] # necessary to avoid passing extraneous args into attn_class while allowing the use of **kwargs args_to_exclude_in_attn_class = { @@ -227,17 +230,20 @@ def __init__( normalized_shape=d_model, device=device, ) - self.attn = attn_class( - d_model=d_model, - n_heads=n_heads, - fc_type=fc_type, - device=device, - **attn_config_subset_for_attn_class, - bias=not no_bias, + self.attn = build_attention_layer( + name=attn_config['attn_type'], + attn_kwargs={ + 'd_model': d_model, + 'n_heads': n_heads, + 'fc_type': fc_type, + 'device': device, + 'bias': not no_bias, + **attn_config_subset_for_attn_class + }, ) + self.norm_2 = None - if not getattr(FFN_CLASS_REGISTRY[ffn_config['ffn_type']], '_has_norm', - False): + if not ffn_has_norm: self.norm_2 = build_norm( name=norm_type.lower(), normalized_shape=d_model, diff --git a/llmfoundry/models/layers/dmoe.py b/llmfoundry/models/layers/dmoe.py index 1a981b61c5..19cd67b8aa 100644 --- a/llmfoundry/models/layers/dmoe.py +++ b/llmfoundry/models/layers/dmoe.py @@ -1,7 +1,7 @@ # Copyright 2024 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 -from typing import Callable +from typing import Callable, Optional import torch @@ -24,7 +24,8 @@ class LearnedRouter(torch.nn.Module): def __init__(self, hidden_size: int, moe_num_experts: int, moe_top_k: int, moe_jitter_eps: float, moe_normalize_expert_weights: bool, - uniform_expert_assignment: bool, device: torch.device) -> None: + uniform_expert_assignment: bool, + device: Optional[torch.device]) -> None: super().__init__() self.hidden_size: int = hidden_size self.moe_num_experts: int = moe_num_experts @@ -84,7 +85,7 @@ def __init__( ffn_hidden_size: int, moe_num_experts: int, activation_fn: Callable, - device: torch.device, + device: Optional[torch.device], ) -> None: super().__init__() @@ -117,9 +118,14 @@ def forward(self, x: torch.Tensor, expert_idx: int) -> torch.Tensor: class GLU(torch.nn.Module): - def __init__(self, hidden_size: int, ffn_hidden_size: int, - moe_num_experts: int, activation_fn: Callable, - device: torch.device): + def __init__( + self, + hidden_size: int, + ffn_hidden_size: int, + moe_num_experts: int, + activation_fn: Callable, + device: Optional[torch.device], + ): super().__init__() self.hidden_size = hidden_size self.ffn_hidden_size = ffn_hidden_size @@ -157,9 +163,16 @@ def forward(self, x: torch.Tensor, expert_idx: torch.Tensor): class DroplessMLP(torch.nn.Module): - def __init__(self, hidden_size: int, ffn_hidden_size: int, mlp_type: str, - moe_num_experts: int, activation_fn: Callable, bias: bool, - device: torch.device): + def __init__( + self, + hidden_size: int, + ffn_hidden_size: int, + mlp_type: str, + moe_num_experts: int, + activation_fn: Callable, + bias: bool, + device: Optional[torch.device], + ): super().__init__() self.moe_num_experts = moe_num_experts @@ -209,12 +222,20 @@ def forward(self, x: torch.Tensor, scores: torch.Tensor, class dMoE(torch.nn.Module): - def __init__(self, hidden_size: int, ffn_hidden_size: int, - moe_num_experts: int, moe_top_k: int, mlp_type: str, - activation_fn: Callable, moe_jitter_eps: float, - moe_normalize_expert_weights: bool, - uniform_expert_assignment: bool, bias: bool, - device: torch.device): + def __init__( + self, + hidden_size: int, + ffn_hidden_size: int, + moe_num_experts: int, + moe_top_k: int, + mlp_type: str, + activation_fn: Callable, + moe_jitter_eps: float, + moe_normalize_expert_weights: bool, + uniform_expert_assignment: bool, + bias: bool, + device: Optional[torch.device], + ): super().__init__() # Token router. diff --git a/llmfoundry/models/layers/fc.py b/llmfoundry/models/layers/fc.py index b85bc133bd..8650e4966f 100644 --- a/llmfoundry/models/layers/fc.py +++ b/llmfoundry/models/layers/fc.py @@ -3,12 +3,12 @@ from torch import nn -FC_CLASS_REGISTRY = { - 'torch': nn.Linear, -} +from llmfoundry.layers_registry import fcs + +fcs.register('torch', func=nn.Linear) try: import transformer_engine.pytorch as te - FC_CLASS_REGISTRY['te'] = te.Linear + fcs.register('te', func=te.Linear) except: pass diff --git a/llmfoundry/models/layers/ffn.py b/llmfoundry/models/layers/ffn.py index 48d3d8c267..fb663b4c3c 100644 --- a/llmfoundry/models/layers/ffn.py +++ b/llmfoundry/models/layers/ffn.py @@ -10,10 +10,13 @@ import torch import torch.nn as nn +from torch.distributed import ProcessGroup from torch.distributed._tensor import DeviceMesh, DTensor, Placement, Shard +from llmfoundry.layers_registry import (ffns, ffns_with_megablocks, + ffns_with_norm) from llmfoundry.models.layers.dmoe import dMoE -from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY +from llmfoundry.models.layers.layer_builders import build_fc try: import transformer_engine.pytorch as te @@ -52,7 +55,7 @@ def resolve_ffn_act_fn( config = deepcopy(config) name = config.pop('name') if not hasattr(torch.nn.functional, name): - raise ValueError(f'Unrecognised activation function name ({name}).') + raise ValueError(f'Unrecognized activation function name ({name}).') act = getattr(torch.nn.functional, name) return partial(act, **config) @@ -121,16 +124,18 @@ def __init__( self.fc_kwargs['device'] = device - self.up_proj = FC_CLASS_REGISTRY[fc_type]( - d_model, - ffn_hidden_size, - **self.fc_kwargs, + self.up_proj = build_fc( + name=fc_type, + in_features=d_model, + out_features=ffn_hidden_size, + fc_kwargs=self.fc_kwargs, ) self.act = act_fn - self.down_proj = FC_CLASS_REGISTRY[fc_type]( - ffn_hidden_size, - d_model, - **self.fc_kwargs, + self.down_proj = build_fc( + name=fc_type, + in_features=ffn_hidden_size, + out_features=d_model, + fc_kwargs=self.fc_kwargs, ) self.down_proj._is_residual = True @@ -159,35 +164,58 @@ def __init__( device=device, bias=bias, ) - self.gate_proj = FC_CLASS_REGISTRY[fc_type]( - d_model, - self.up_proj.out_features, - **self.fc_kwargs, + self.gate_proj = build_fc( + name=fc_type, + in_features=d_model, + out_features=self.up_proj.out_features, + fc_kwargs=self.fc_kwargs, ) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x)) -FFN_CLASS_REGISTRY = { - 'mptmlp': MPTMLP, - 'mptglu': MPTGLU, - 'torch_dmoe': dMoE, -} +def build_mptglu( + d_model: int, + expansion_ratio: Union[int, float], + fc_type: str = 'torch', + ffn_hidden_size: Optional[int] = None, + ffn_act_fn: Optional[dict] = None, + device: Optional[str] = None, + bias: bool = True, +) -> nn.Module: + return MPTGLU( + d_model=d_model, + expansion_ratio=expansion_ratio, + fc_type=fc_type, + act_fn=resolve_ffn_act_fn(ffn_act_fn), + ffn_hidden_size=ffn_hidden_size, + device=device, + bias=bias, + ) -if is_te_imported: - import transformer_engine.pytorch as te - te.LayerNormMLP._has_norm = True - FFN_CLASS_REGISTRY['te_ln_mlp'] = te.LayerNormMLP -if is_megablocks_imported: - import megablocks - - FFN_CLASS_REGISTRY['mb_moe'] = megablocks.layers.moe.MoE - FFN_CLASS_REGISTRY['mb_dmoe'] = megablocks.layers.dmoe.dMoE +def build_mptmlp( + d_model: int, + expansion_ratio: Union[int, float], + fc_type: str = 'torch', + ffn_hidden_size: Optional[int] = None, + ffn_act_fn: Optional[dict] = None, + device: Optional[str] = None, + bias: bool = True, +) -> nn.Module: + return MPTMLP( + d_model=d_model, + expansion_ratio=expansion_ratio, + fc_type=fc_type, + act_fn=resolve_ffn_act_fn(ffn_act_fn), + ffn_hidden_size=ffn_hidden_size, + device=device, + bias=bias, + ) -def build_ffn( +def build_te_ln_mlp( d_model: int, expansion_ratio: Union[int, float], fc_type: str = 'torch', @@ -197,131 +225,225 @@ def build_ffn( bias: bool = True, **kwargs: Any, ) -> nn.Module: - ffn_type = kwargs.pop('ffn_type') - if ffn_type in ['mptmlp', 'mptglu']: - if len(kwargs) > 0: - raise ValueError( - f'MPTMLP (or MPTGLU) got an unexpected keyword argument: {kwargs}' - ) - return FFN_CLASS_REGISTRY[ffn_type]( - d_model=d_model, - expansion_ratio=expansion_ratio, - fc_type=fc_type, - act_fn=resolve_ffn_act_fn(ffn_act_fn), - ffn_hidden_size=ffn_hidden_size, - device=device, - bias=bias, + assert te is not None + ffn_hidden_size = resolve_ffn_hidden_size(d_model, expansion_ratio, + ffn_hidden_size) + if ffn_act_fn is not None: + raise ValueError( + f'Transformer Engine block does not support custom activation functions.' ) - elif ffn_type == 'te_ln_mlp': - if te is None: - raise RuntimeError( - 'Requirements for TransformerEngine not installed; see install instructions in `README.md`.' - ) - ffn_hidden_size = resolve_ffn_hidden_size(d_model, expansion_ratio, - ffn_hidden_size) - if ffn_act_fn is not None: - raise ValueError( - f'Transformer Engine block does not support custom activation functions.' - ) - return te.LayerNormMLP( - hidden_size=d_model, - ffn_hidden_size=ffn_hidden_size, - bias=bias, - **kwargs, - ) - elif ffn_type in ('mb_moe', 'mb_dmoe'): - if megablocks is None: - raise RuntimeError( - 'Requirements for megablocks not installed; see install instructions in `README.md`.' - ) - args = kwargs['args'] - args.bias = bias - args.hidden_size = d_model - args.device = device + return te.LayerNormMLP( + hidden_size=d_model, + ffn_hidden_size=ffn_hidden_size, + bias=bias, + **kwargs, + ) + + +def build_torch_dmoe( + d_model: int, + expansion_ratio: Union[int, float], + ffn_hidden_size: Optional[int] = None, + ffn_act_fn: Optional[dict] = None, + device: Optional[str] = None, + bias: bool = True, + **kwargs: Any, +) -> nn.Module: + moe_num_experts = kwargs.pop('moe_num_experts') + moe_top_k = kwargs.pop('moe_top_k') + mlp_type = kwargs.pop('mlp_type') + moe_jitter_eps = kwargs.pop('moe_jitter_eps') + moe_normalize_expert_weights = kwargs.pop('moe_normalize_expert_weights') + uniform_expert_assignment = kwargs.pop('uniform_expert_assignment') + + fc_type = kwargs.pop('fc_type', 'torch') + del fc_type # Unused + + if len(kwargs) > 0: + raise ValueError(f'Invalid arguments to torch dmoe: {kwargs}.') + + return dMoE( + hidden_size=d_model, + ffn_hidden_size=resolve_ffn_hidden_size(d_model, expansion_ratio, + ffn_hidden_size), + moe_num_experts=moe_num_experts, + moe_top_k=moe_top_k, + mlp_type=mlp_type, + bias=bias, + moe_jitter_eps=moe_jitter_eps, + activation_fn=resolve_ffn_act_fn(ffn_act_fn), + moe_normalize_expert_weights=moe_normalize_expert_weights, + uniform_expert_assignment=uniform_expert_assignment, + device=torch.device(device) if device is not None else None, + ) - ffn_hidden_size = resolve_ffn_hidden_size(d_model, expansion_ratio, - ffn_hidden_size) - args.ffn_hidden_size = ffn_hidden_size - - if ffn_act_fn is not None: - args.activation_fn = resolve_ffn_act_fn(ffn_act_fn) - - moe_world_size = 1 - expert_parallel_group = args.expert_parallel_group - if expert_parallel_group is not None: - moe_world_size = expert_parallel_group.size() - if kwargs.get('moe_world_size') != moe_world_size: - raise RuntimeError( - f'MoE expert_parallel_group configured with incorrect world size.' - ) - if ffn_type == 'mb_moe': - ffn = megablocks.layers.moe.MoE(args) - - # Fused initialization setup - # For param_init_fn, enables shape based init of stacked layers - ffn.experts.mlp._stack_dim = 0 - elif ffn_type == 'mb_dmoe': - ffn = megablocks.layers.dmoe.dMoE(args) - - # Fused initialization setup - # For param_init_fn, enables shape based init of fused layers - n_exp = min(1, args.moe_num_experts // moe_world_size) - ffn.experts.mlp._fused = (0, [ - (n + 1) * args.ffn_hidden_size for n in range(n_exp - 1) - ]) +def _mb_setup_args( + d_model: int, + expansion_ratio: Union[int, float], + ffn_hidden_size: Optional[int], + ffn_act_fn: Optional[dict], + device: Optional[str], + bias: bool, + kwargs: dict[str, Any], +) -> tuple['megablocks.layers.arguments.Arguments', int, ProcessGroup]: + if megablocks is None: + raise RuntimeError( + 'Requirements for megablocks not installed; see install instructions in `README.md`.' + ) + args = kwargs['args'] + args.bias = bias + args.hidden_size = d_model + args.device = device + + ffn_hidden_size = resolve_ffn_hidden_size(d_model, expansion_ratio, + ffn_hidden_size) + args.ffn_hidden_size = ffn_hidden_size + + if ffn_act_fn is not None: + args.activation_fn = resolve_ffn_act_fn(ffn_act_fn) + + moe_world_size = 1 + expert_parallel_group = args.expert_parallel_group + if expert_parallel_group is not None: + moe_world_size = expert_parallel_group.size() + if kwargs.get('moe_world_size') != moe_world_size: + raise RuntimeError( + f'MoE expert_parallel_group configured with incorrect world size.') + + return args, moe_world_size, expert_parallel_group + + +def _patch_ffn_mb( + ffn: nn.Module, + moe_world_size: int, + expert_parallel_group: ProcessGroup, + device_mesh: DeviceMesh, + args: 'megablocks.layers.arguments.Arguments', +): + # Attach args to MLP directly for use in param_init_fn + ffn.experts.mlp.hidden_size = args.ffn_hidden_size + ffn.experts.mlp.expert_parallel_group = expert_parallel_group + ffn.experts.mlp.weight_parallel_group = args.weight_parallel_group + + if moe_world_size > 1: + expert_mesh = device_mesh['expert_parallel'] + expert_placements: List[Placement] = [Shard(0)] + # Register in two loops as you cannot overwrite parameters while iterating over named_parameters() + dtensorified_params = [ + (name, + dtensorify_param(param=parameter, + mesh=expert_mesh, + placements=expert_placements)) + for name, parameter in ffn.experts.mlp.named_parameters() + ] + for name, dtensorified_param in dtensorified_params: + ffn.experts.mlp.register_parameter(name, dtensorified_param) + + if device_mesh.mesh.ndim == 2: + submesh = device_mesh['weight_parallel'] + elif device_mesh.mesh.ndim == 3: + raise RuntimeError(f'HSDP + MoE is not supported.') else: - raise RuntimeError(f'Invalid ffn_type option: {ffn_type}.') - - # Attach args to MLP directly for use in param_init_fn - ffn.experts.mlp.hidden_size = args.ffn_hidden_size - ffn.experts.mlp.expert_parallel_group = expert_parallel_group - ffn.experts.mlp.weight_parallel_group = args.weight_parallel_group - - if moe_world_size > 1: - device_mesh = kwargs['device_mesh'] - - expert_mesh = device_mesh['expert_parallel'] - expert_placements: List[Placement] = [Shard(0)] - # Register in two loops as you cannot overwrite parameters while iterating over named_parameters() - dtensorified_params = [ - (name, - dtensorify_param(param=parameter, - mesh=expert_mesh, - placements=expert_placements)) - for name, parameter in ffn.experts.mlp.named_parameters() - ] - for name, dtensorified_param in dtensorified_params: - ffn.experts.mlp.register_parameter(name, dtensorified_param) - - device_mesh = kwargs['device_mesh'] - if device_mesh.mesh.ndim == 2: - submesh = device_mesh['weight_parallel'] - elif device_mesh.mesh.ndim == 3: - raise RuntimeError(f'HSDP + MoE is not supported.') - else: - raise ValueError( - f'{device_mesh.mesh.ndim=} not supported for MoE.') - - ffn.experts._fsdp_kwargs_dict = { - 'device_mesh': submesh, - } - return ffn - elif ffn_type == 'torch_dmoe': - return dMoE( - hidden_size=d_model, - ffn_hidden_size=resolve_ffn_hidden_size(d_model, expansion_ratio, - ffn_hidden_size), - moe_num_experts=kwargs.pop('moe_num_experts'), - moe_top_k=kwargs.pop('moe_top_k'), - mlp_type=kwargs.pop('mlp_type'), - bias=bias, - moe_jitter_eps=kwargs.pop('moe_jitter_eps'), - activation_fn=resolve_ffn_act_fn(ffn_act_fn), - moe_normalize_expert_weights=kwargs.pop( - 'moe_normalize_expert_weights'), - uniform_expert_assignment=kwargs.pop('uniform_expert_assignment'), - device=device, # pyright: ignore[reportGeneralTypeIssues] + raise ValueError(f'{device_mesh.mesh.ndim=} not supported for MoE.') + + ffn.experts._fsdp_kwargs_dict = { + 'device_mesh': submesh, + } + + +def build_mb_moe( + d_model: int, + expansion_ratio: Union[int, float], + ffn_hidden_size: Optional[int] = None, + ffn_act_fn: Optional[dict] = None, + device: Optional[str] = None, + bias: bool = True, + **kwargs: Any, +) -> nn.Module: + if not is_megablocks_imported: + raise RuntimeError( + 'Requirements for megablocks not installed; see install instructions in `README.md`.' + ) + + args, moe_world_size, expert_parallel_group = _mb_setup_args( + d_model=d_model, + expansion_ratio=expansion_ratio, + ffn_hidden_size=ffn_hidden_size, + ffn_act_fn=ffn_act_fn, + device=device, + bias=bias, + kwargs=kwargs, + ) + + ffn = megablocks.layers.moe.MoE(args) + + # Fused initialization setup + # For param_init_fn, enables shape based init of stacked layers + ffn.experts.mlp._stack_dim = 0 + + _patch_ffn_mb( + ffn=ffn, + moe_world_size=moe_world_size, + expert_parallel_group=expert_parallel_group, + device_mesh=kwargs['device_mesh'], + args=args, + ) + + return ffn + + +def build_mb_dmoe( + d_model: int, + expansion_ratio: Union[int, float], + ffn_hidden_size: Optional[int] = None, + ffn_act_fn: Optional[dict] = None, + device: Optional[str] = None, + bias: bool = True, + **kwargs: Any, +) -> nn.Module: + if not is_megablocks_imported: + raise RuntimeError( + 'Requirements for megablocks not installed; see install instructions in `README.md`.' ) - raise ValueError(f'{ffn_type=} not recognized.') + args, moe_world_size, expert_parallel_group = _mb_setup_args( + d_model=d_model, + expansion_ratio=expansion_ratio, + ffn_hidden_size=ffn_hidden_size, + ffn_act_fn=ffn_act_fn, + device=device, + bias=bias, + kwargs=kwargs, + ) + + ffn = megablocks.layers.dmoe.dMoE(args) + + # Fused initialization setup + # For param_init_fn, enables shape based init of fused layers + n_exp = min(1, args.moe_num_experts // moe_world_size) + ffn.experts.mlp._fused = (0, [ + (n + 1) * args.ffn_hidden_size for n in range(n_exp - 1) + ]) + + _patch_ffn_mb( + ffn=ffn, + moe_world_size=moe_world_size, + expert_parallel_group=expert_parallel_group, + device_mesh=kwargs['device_mesh'], + args=args, + ) + + return ffn + + +ffns.register('mptglu', func=build_mptglu) +ffns.register('mptmlp', func=build_mptmlp) +ffns.register('torch_dmoe', func=build_torch_dmoe) + +if is_te_imported: + ffns_with_norm.register('te_ln_mlp', func=build_te_ln_mlp) + +if is_megablocks_imported: + ffns_with_megablocks.register('mb_moe', func=build_mb_moe) + ffns_with_megablocks.register('mb_dmoe', func=build_mb_dmoe) diff --git a/llmfoundry/models/layers/layer_builders.py b/llmfoundry/models/layers/layer_builders.py index 23f5b89668..425fcaf862 100644 --- a/llmfoundry/models/layers/layer_builders.py +++ b/llmfoundry/models/layers/layer_builders.py @@ -1,11 +1,13 @@ # Copyright 2024 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 -from typing import List, Optional, Union +from typing import Any, Dict, List, Optional, Union import torch -from llmfoundry.layers_registry import norms +from llmfoundry.layers_registry import (attention_classes, fcs, ffns, + ffns_with_megablocks, ffns_with_norm, + norms) from llmfoundry.utils.registry_utils import construct_from_registry @@ -23,3 +25,75 @@ def build_norm( registry=norms, pre_validation_function=torch.nn.Module, kwargs=kwargs) + + +def build_ffn( + name: str, + d_model: int, + expansion_ratio: float, + device: Optional[str], + bias: bool, + ffn_kwargs: Dict[str, Any], +): + + registry_to_use = ffns + if name in ffns_with_norm: + registry_to_use = ffns_with_norm + + if name in ffns_with_megablocks: + registry_to_use = ffns_with_megablocks + + kwargs = { + 'd_model': d_model, + 'expansion_ratio': expansion_ratio, + 'device': device, + 'bias': bias, + **{k: v for k, v in ffn_kwargs.items() if k != 'ffn_type'}, + } + + def _validation_function(maybe_module: Any): + if not isinstance(maybe_module, torch.nn.Module): + raise ValueError(f'Function {name} must return a torch.nn.Module.') + + result = construct_from_registry( + name=name, + registry=registry_to_use, + post_validation_function=_validation_function, + partial_function=False, + kwargs=kwargs) + + if name in ffns_with_norm: + result._has_norm = True + + if name in ffns_with_megablocks: + result._uses_megablocks = True + + return result + + +def build_attention_layer( + name: str, + attn_kwargs: Dict[str, Any], +): + return construct_from_registry(name=name, + registry=attention_classes, + pre_validation_function=torch.nn.Module, + kwargs=attn_kwargs) + + +def build_fc( + name: str, + in_features: int, + out_features: int, + fc_kwargs: Dict[str, Any], +): + kwargs = { + 'in_features': in_features, + 'out_features': out_features, + **fc_kwargs, + } + + return construct_from_registry(name=name, + registry=fcs, + pre_validation_function=torch.nn.Module, + kwargs=kwargs) diff --git a/llmfoundry/models/mpt/configuration_mpt.py b/llmfoundry/models/mpt/configuration_mpt.py index 8383d33ec0..dbee232f3d 100644 --- a/llmfoundry/models/mpt/configuration_mpt.py +++ b/llmfoundry/models/mpt/configuration_mpt.py @@ -8,6 +8,7 @@ from transformers import PretrainedConfig +from llmfoundry.layers_registry import ffns_with_megablocks from llmfoundry.models.layers.attention import (check_alibi_support, is_flash_v2_installed) from llmfoundry.models.layers.blocks import attn_config_defaults @@ -16,11 +17,9 @@ # HuggingFace can detect all the needed files to copy into its modules folder. # Otherwise, certain modules are missing. # isort: off -from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY # type: ignore (see note) from llmfoundry.models.layers.norm import LPLayerNorm # type: ignore (see note) -from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY # type: ignore (see note) +from llmfoundry.models.layers.layer_builders import build_norm, build_fc, build_ffn # type: ignore (see note) from llmfoundry.models.layers.dmoe import dMoE # type: ignore (see note) -from llmfoundry.models.layers.layer_builders import build_norm # type: ignore (see note) from llmfoundry.layers_registry import norms # type: ignore (see note) from llmfoundry.utils.registry_utils import construct_from_registry # type: ignore (see note) @@ -291,7 +290,7 @@ def _validate_config(self) -> None: ) elif self.ffn_config['ffn_type'] in ['mptmlp', 'mptglu']: self.ffn_config['fc_type'] = self.fc_type - elif self.ffn_config['ffn_type'] in ['mb_moe', 'mb_dmoe']: + elif self.ffn_config['ffn_type'] in ffns_with_megablocks: self.ffn_config['return_bias'] = False elif self.ffn_config['ffn_type'] == 'te_ln_mlp': self.ffn_config['bias'] = not self.no_bias diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 4a8f3943af..1ef62a3b19 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -20,6 +20,7 @@ from composer.models import HuggingFaceModel from composer.utils import dist +from llmfoundry.layers_registry import ffns_with_megablocks from llmfoundry.models.layers.attention import is_flash_v2_installed if is_flash_v2_installed(): @@ -42,12 +43,11 @@ from transformers.models.llama.modeling_llama import \ LlamaRotaryEmbedding as HFRotaryEmbedding -from llmfoundry.layers_registry import norms +from llmfoundry.layers_registry import norms, param_init_fns from llmfoundry.models.layers.attention import (attn_bias_shape, build_attn_bias, gen_slopes) from llmfoundry.models.layers.blocks import MPTBlock from llmfoundry.models.layers.custom_embedding import SharedEmbedding -from llmfoundry.models.layers.ffn import build_ffn as build_ffn from llmfoundry.models.layers.layer_builders import build_norm from llmfoundry.models.mpt.configuration_mpt import MPTConfig from llmfoundry.models.utils.config_moe_args import config_moe_args @@ -62,8 +62,8 @@ init_empty_weights # type: ignore (see note) from llmfoundry.models.utils.param_init_fns import ( generic_param_init_fn_, # type: ignore (see note) - MODEL_INIT_REGISTRY, ) +from llmfoundry.models.layers.ffn import resolve_ffn_act_fn # type: ignore (see note) from llmfoundry.models.utils.act_ckpt import (pass_on_block_idx, build_act_ckpt_mod_to_blocks, @@ -324,7 +324,7 @@ def __init__(self, config: MPTConfig): self.emb_drop = nn.Dropout(config.emb_pdrop) self.mb_args = None block_args = config.to_dict() - if block_args['ffn_config']['ffn_type'] in ('mb_moe', 'mb_dmoe'): + if block_args['ffn_config']['ffn_type'] in ffns_with_megablocks: block_args['ffn_config'] = config_moe_args( block_args['ffn_config'], config.d_model, @@ -332,6 +332,7 @@ def __init__(self, config: MPTConfig): config.n_layers, ) self.mb_args = block_args['ffn_config'].get('args') + self.blocks = nn.ModuleList([ MPTBlock( device=config.init_device, @@ -676,7 +677,7 @@ def forward( # Param Initialization, needed for device='meta' fast initialization def param_init_fn(self, module: nn.Module) -> None: init_fn_name = self.config.init_config['name'] - MODEL_INIT_REGISTRY[init_fn_name]( + param_init_fns.get(init_fn_name)( module=module, n_layers=self.config.n_layers, d_model=self.config.d_model, @@ -836,7 +837,7 @@ def forward( # Param Initialization, needed for device='meta' fast initialization def param_init_fn(self, module: nn.Module) -> None: init_fn_name = self.config.init_config['name'] - MODEL_INIT_REGISTRY[init_fn_name]( + param_init_fns.get(init_fn_name)( module=module, n_layers=self.config.n_layers, d_model=self.config.d_model, @@ -1026,7 +1027,7 @@ def get_targets(self, batch: Mapping) -> torch.Tensor: return targets def forward(self, batch: MutableMapping) -> CausalLMOutputWithPast: - if self.config.ffn_config['ffn_type'] in ('mb_moe', 'mb_dmoe'): + if self.config.ffn_config['ffn_type'] in ffns_with_megablocks: # Clear MegaBlocks MoE load balancing loss cache try: # Add try/catch to avoid transformers complaining and raising errors from megablocks.layers.moe import clear_load_balancing_loss @@ -1053,7 +1054,7 @@ def loss(self, outputs: CausalLMOutputWithPast, else: loss = losses.sum() / (targets != self.loss_fn.ignore_index).sum() - if self.config.ffn_config['ffn_type'] in ('mb_moe', 'mb_dmoe'): + if self.config.ffn_config['ffn_type'] in ffns_with_megablocks: # MegaBlocks MoE load balancing loss try: # Add try/catch to avoid transformers complaining and raising errors from megablocks.layers.moe import batched_load_balancing_loss diff --git a/llmfoundry/models/utils/__init__.py b/llmfoundry/models/utils/__init__.py index 41313b8729..ca5fa4b935 100644 --- a/llmfoundry/models/utils/__init__.py +++ b/llmfoundry/models/utils/__init__.py @@ -6,14 +6,12 @@ init_on_device) from llmfoundry.models.utils.mpt_param_count import (mpt_get_active_params, mpt_get_total_params) -from llmfoundry.models.utils.param_init_fns import (MODEL_INIT_REGISTRY, - generic_param_init_fn_) +from llmfoundry.models.utils.param_init_fns import generic_param_init_fn_ __all__ = [ 'init_empty_weights', 'init_on_device', 'generic_param_init_fn_', - 'MODEL_INIT_REGISTRY', 'config_moe_args', 'mpt_get_active_params', 'mpt_get_total_params', diff --git a/llmfoundry/models/utils/act_ckpt.py b/llmfoundry/models/utils/act_ckpt.py index 1975865f1b..e6cd8bdc58 100644 --- a/llmfoundry/models/utils/act_ckpt.py +++ b/llmfoundry/models/utils/act_ckpt.py @@ -5,10 +5,10 @@ import torch -from llmfoundry.layers_registry import norms -from llmfoundry.models.layers.attention import ATTN_CLASS_REGISTRY +from llmfoundry.layers_registry import (attention_classes, ffns, + ffns_with_megablocks, ffns_with_norm, + norms) from llmfoundry.models.layers.blocks import FusedNormAttentionNorm, MPTBlock -from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY def pass_on_block_idx(parent: torch.nn.Module): @@ -25,18 +25,24 @@ def get_act_ckpt_module(mod_name: str) -> Any: """Get the module type from the module name.""" if mod_name.lower() == 'mptblock': mod_type = MPTBlock + elif mod_name in attention_classes: + mod_type = attention_classes.get(mod_name) elif mod_name.lower() == 'norm_attn_norm': mod_type = FusedNormAttentionNorm - elif mod_name in ATTN_CLASS_REGISTRY: - mod_type = ATTN_CLASS_REGISTRY[mod_name] - elif mod_name in FFN_CLASS_REGISTRY: - mod_type = FFN_CLASS_REGISTRY[mod_name] + elif mod_name in ffns: + mod_type = ffns.get(mod_name) + elif mod_name in ffns_with_norm: + mod_type = ffns_with_norm.get(mod_name) + elif mod_name in ffns_with_megablocks: + mod_type = ffns_with_megablocks.get(mod_name) elif mod_name in norms: mod_type = norms.get(mod_name) else: msg = ', '.join( - list(ATTN_CLASS_REGISTRY.keys()) + list(FFN_CLASS_REGISTRY.keys()) + - list(norms.get_all()) + ['MPTBlock']) + list(attention_classes.keys()) + list(ffns.get_all()) + + list(ffns_with_norm.get_all()) + + list(ffns_with_megablocks.get_all()) + list(norms.get_all()) + + ['MPTBlock']) raise ValueError( f'{mod_name} (specified in activation_checkpointing_target) is not a recognized option out of available options {msg}.' ) diff --git a/llmfoundry/models/utils/config_moe_args.py b/llmfoundry/models/utils/config_moe_args.py index b69cd18348..4de9a47bbc 100644 --- a/llmfoundry/models/utils/config_moe_args.py +++ b/llmfoundry/models/utils/config_moe_args.py @@ -9,9 +9,31 @@ from packaging import version from torch import distributed +from llmfoundry.layers_registry import ffns_with_megablocks from llmfoundry.models.layers.ffn import resolve_ffn_hidden_size +def create_process_group_ranks(ranks: tuple[int]): + """Creates a new distributed group. + + Used in create_set_process_group and create_mod_process_group methods below. + + This function is an alternative to `distributed.new_group(ranks)`. + + Args: + ranks (tuple[int]): Tuple of ranks of group members. + + Returns: + A handle of distributed group that can be given to collective calls. + """ + ranks_gather_list = [None for _ in range(distributed.get_world_size())] + distributed.all_gather_object(ranks_gather_list, ranks) + ranks_per_subgroup = list(set(ranks_gather_list)) + group, _ = distributed.distributed_c10d.new_subgroups_by_enumeration( + ranks_per_subgroup) + return group + + def create_set_process_group(k: int): """Creates a new distributed group using sets of k GPUs. @@ -33,7 +55,7 @@ def create_set_process_group(k: int): raise RuntimeError(f'{world_size=} must be divisible by {k=}.') start = distributed.get_rank() // k * k ranks = tuple(range(start, start + k)) - return distributed.new_group(ranks) + return create_process_group_ranks(ranks) def config_megablocks_moe_args( @@ -156,7 +178,7 @@ def config_moe_args( Returns: ffn_config (dict): FFN configuration with MoE configured. """ - if ffn_config['ffn_type'] in ('mb_moe', 'mb_dmoe'): + if ffn_config['ffn_type'] in ffns_with_megablocks: return config_megablocks_moe_args( ffn_config=ffn_config, d_model=d_model, diff --git a/llmfoundry/models/utils/mpt_param_count.py b/llmfoundry/models/utils/mpt_param_count.py index d90929713b..cb1a5c0935 100644 --- a/llmfoundry/models/utils/mpt_param_count.py +++ b/llmfoundry/models/utils/mpt_param_count.py @@ -16,6 +16,8 @@ from torch import Tensor, nn from torch.distributed._tensor import DTensor +from llmfoundry.layers_registry import ffns_with_megablocks + def module_n_params(module: nn.Module) -> int: """Gets the number of parameters in this module excluding child modules. @@ -127,7 +129,7 @@ def megablocks_n_active_params(mpt_model) -> int: # type: ignore def mpt_get_total_params(mpt_model) -> int: # type: ignore - """Calculates the total paramter count of an MPT model. + """Calculates the total parameter count of an MPT model. Note: Must be called before model parameters are sharded by FSDP. @@ -138,14 +140,14 @@ def mpt_get_total_params(mpt_model) -> int: # type: ignore Returns: An int for the total number of parameters in this MPT model. """ - if mpt_model.config.ffn_config['ffn_type'] in ('mb_moe', 'mb_dmoe'): + if mpt_model.config.ffn_config['ffn_type'] in ffns_with_megablocks: return megablocks_n_total_params(mpt_model) else: return sum(p.numel() for p in mpt_model.parameters()) def mpt_get_active_params(mpt_model) -> int: # type: ignore - """Calculates the total paramter count of an MPT model. + """Calculates the total parameter count of an MPT model. Note: Must be called before model parameters are sharded by FSDP. @@ -156,7 +158,7 @@ def mpt_get_active_params(mpt_model) -> int: # type: ignore Returns: An int for the active number of parameters in this MPT model. """ - if mpt_model.config.ffn_config['ffn_type'] in ('mb_moe', 'mb_dmoe'): + if mpt_model.config.ffn_config['ffn_type'] in ffns_with_megablocks: params = megablocks_n_active_params(mpt_model) else: params = sum(p.numel() for p in mpt_model.parameters()) diff --git a/llmfoundry/models/utils/param_init_fns.py b/llmfoundry/models/utils/param_init_fns.py index 16376de451..e64cde5e96 100644 --- a/llmfoundry/models/utils/param_init_fns.py +++ b/llmfoundry/models/utils/param_init_fns.py @@ -12,9 +12,9 @@ from torch import nn from torch.distributed._tensor import DTensor -from llmfoundry.layers_registry import norms +from llmfoundry.layers_registry import (fcs, module_init_fns, norms, + param_init_fns) from llmfoundry.models.layers.dmoe import GLU, MLP -from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY try: import transformer_engine.pytorch as te @@ -54,7 +54,7 @@ def fused_init_helper_( Args: module (nn.Module): The module to initialize. init_fn_ (Callable): Initialization method. - name_param (str): Name of parameter to initalize within the module. + name_param (str): Name of parameter to initialize within the module. """ _fused = getattr(module, '_fused', None) if _fused is None: @@ -90,7 +90,7 @@ def stacked_init_helper_( init_fn_: Callable, name_param: str = 'weight', ): - """Initializes parameters stacked along a new dimention. + """Initializes parameters stacked along a new dimension. Parameter initialization is often based on the parameters shape. If a layer is stacked, initialization should be based on the shapes of the original tensor instead of the @@ -100,7 +100,7 @@ def stacked_init_helper_( Args: module (nn.Module): The module to initialize. init_fn_ (Callable): Initialization method. - name_param (str): Name of parameter to initalize within the module. + name_param (str): Name of parameter to initialize within the module. """ stack_dim = getattr(module, '_stack_dim', None) if stack_dim is None: @@ -114,12 +114,12 @@ def stacked_param_init_helper( init_fn_: Callable, stack_dim: int, ): - """Initialize parameters stacked along a new dimention. + """Initialize parameters stacked along a new dimension. Args: param (torch.Tensor): Tensor to initialize. init_fn_ (Callable): Initialization method. - stack_dim (int): Dimention along with parameters are stacked + stack_dim (int): Dimension along with parameters are stacked """ p_ndims = param.ndim @@ -148,41 +148,16 @@ def _flip_fan_mode(init_fn_: Callable): return _init_fn_ -def generic_param_init_fn_( +def fc_init( module: nn.Module, init_fn_: Callable, - n_layers: int, - d_model: Optional[int] = None, - init_div_is_residual: Union[int, float, str, bool] = True, - emb_init_std: Optional[float] = None, - emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None, + init_div_is_residual: Union[int, float, str, bool], + div_is_residual: Optional[float], **kwargs: Any, -) -> None: - del kwargs # unused, just to capture any extra args from the config - # enable user to divide _is_residual weights by - - # a value which defaults to math.sqrt(2 * cfg.n_layers) - init_div_is_residual = init_div_is_residual +) -> bool: + del kwargs # unused, just to capture any extra args - if init_div_is_residual is False: - # not used, for pyright - div_is_residual = 1.0 - elif init_div_is_residual is True: - div_is_residual = math.sqrt(2 * n_layers) - elif isinstance(init_div_is_residual, float) or isinstance( - init_div_is_residual, int): - div_is_residual = init_div_is_residual - elif init_div_is_residual.isnumeric(): - # do not trust YAML parsing to always convert numbers to numbers - div_is_residual = float(init_div_is_residual) - else: - # not used, for pyright - div_is_residual = 1.0 - raise ValueError( - f'Expected init_div_is_residual to be boolean or numeric, got {init_div_is_residual}' - ) - - if isinstance(module, tuple(set(FC_CLASS_REGISTRY.values()))): + if isinstance(module, tuple(set([fcs.get(n) for n in fcs.get_all()]))): # Linear if hasattr(module, '_fused'): fused_init_helper_(module, init_fn_) @@ -196,8 +171,21 @@ def generic_param_init_fn_( module, '_is_residual', False): with torch.no_grad(): module.weight.div_(div_is_residual) # type: ignore + return True - elif isinstance(module, nn.Embedding): + return False + + +def embedding_init( + module: nn.Module, + init_fn_: Callable, + emb_init_std: Optional[float], + emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]], + **kwargs: Any, +) -> bool: + del kwargs # unused, just to capture any extra args + + if isinstance(module, nn.Embedding): # Embedding if emb_init_std is not None: std = emb_init_std @@ -224,8 +212,19 @@ def generic_param_init_fn_( emb_init_fn_(module.weight) - elif isinstance(module, - tuple(set([norms.get(name) for name in norms.get_all()]))): + return True + + return False + + +def norm_init( + module: nn.Module, + **kwargs: Any, +) -> bool: + del kwargs # unused, just to capture any extra args + + if isinstance(module, + tuple(set([norms.get(name) for name in norms.get_all()]))): # Norm if hasattr(module, 'weight') and isinstance(module.weight, torch.Tensor): @@ -233,7 +232,22 @@ def generic_param_init_fn_( if hasattr(module, 'bias') and isinstance(module.bias, torch.Tensor): torch.nn.init.zeros_(module.bias) - elif isinstance(module, nn.MultiheadAttention): + return True + + return False + + +def multihead_attention_init( + module: nn.Module, + init_fn_: Callable, + d_model: Optional[int], + init_div_is_residual: Union[int, float, str, bool], + div_is_residual: float, + **kwargs: Any, +) -> bool: + del kwargs # unused, just to capture any extra args + + if isinstance(module, nn.MultiheadAttention): # torch's MultiheadAttention if module._qkv_same_embed_dim: assert module.in_proj_weight is not None @@ -268,7 +282,19 @@ def generic_param_init_fn_( if module.out_proj.bias is not None: torch.nn.init.zeros_(module.out_proj.bias) - elif te is not None and isinstance(module, te.LayerNormMLP): + return True + + return False + + +def te_layernorm_mlp_init( + module: nn.Module, + init_fn_: Callable, + **kwargs: Any, +) -> bool: + del kwargs # unused, just to capture any extra args + + if te is not None and isinstance(module, te.LayerNormMLP): if isinstance(module.layer_norm_weight, torch.Tensor): torch.nn.init.ones_(module.layer_norm_weight) if isinstance(module.layer_norm_bias, torch.Tensor): @@ -286,7 +312,19 @@ def generic_param_init_fn_( with torch.no_grad(): module.fc2_weight.div_(div_is_residual) # type: ignore - elif megablocks is not None and isinstance(module, ( + return True + + return False + + +def moe_init( + module: nn.Module, + init_fn_: Callable, + init_div_is_residual: Union[int, float, str, bool], + div_is_residual: float, + **kwargs: Any, +) -> bool: + if megablocks is not None and isinstance(module, ( megablocks.layers.moe.MoE, megablocks.layers.dmoe.dMoE, megablocks.layers.moe.ParallelMLP, @@ -295,32 +333,96 @@ def generic_param_init_fn_( if hasattr(module, 'bias') and module.bias is not None: # Initialize bias to 0 torch.nn.init.zeros_(module.bias) # type: ignore + return True elif megablocks is not None and isinstance(module, megablocks.layers.glu.SparseGLU): _megablocks_sparse_glu_generic_param_init_fn_( module, init_fn_, bool(init_div_is_residual), div_is_residual) + return True elif megablocks is not None and isinstance(module, megablocks.layers.mlp.SparseMLP): _megablocks_sparse_mlp_generic_param_init_fn_( module, init_fn_, bool(init_div_is_residual), div_is_residual) + return True elif megablocks is not None and isinstance(module, megablocks.layers.mlp.MLP): _megablocks_mlp_generic_param_init_fn_(module, init_fn_, bool(init_div_is_residual), div_is_residual) + return True elif isinstance(module, GLU): init_fn_(module.w1) init_fn_(module.v1) init_fn_(module.w2) + return True elif isinstance(module, MLP): init_fn_(module.w1) init_fn_(module.w2) + return True + + return False + + +def generic_param_init_fn_( + module: nn.Module, + init_fn_: Callable, + n_layers: int, + d_model: Optional[int] = None, + init_div_is_residual: Union[int, float, str, bool] = True, + emb_init_std: Optional[float] = None, + emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None, + **kwargs: Any, +) -> None: + del kwargs # unused, just to capture any extra args from the config + # enable user to divide _is_residual weights by + + # a value which defaults to math.sqrt(2 * cfg.n_layers) + init_div_is_residual = init_div_is_residual + + if init_div_is_residual is False: + # not used, for pyright + div_is_residual = 1.0 + elif init_div_is_residual is True: + div_is_residual = math.sqrt(2 * n_layers) + elif isinstance(init_div_is_residual, float) or isinstance( + init_div_is_residual, int): + div_is_residual = init_div_is_residual + elif init_div_is_residual.isnumeric(): + # do not trust YAML parsing to always convert numbers to numbers + div_is_residual = float(init_div_is_residual) else: + # not used, for pyright + div_is_residual = 1.0 + raise ValueError( + f'Expected init_div_is_residual to be boolean or numeric, got {init_div_is_residual}' + ) + + all_module_init_fns = [ + module_init_fns.get(name) for name in module_init_fns.get_all() + ] + did_init = False + for module_init_fn in all_module_init_fns: + did_init = module_init_fn( + module=module, + init_fn_=init_fn_, + d_model=d_model, + init_div_is_residual=init_div_is_residual, + div_is_residual=div_is_residual, + emb_init_std=emb_init_std, + emb_init_uniform_lim=emb_init_uniform_lim, + ) + + if did_init: + break + + if not did_init: for _ in module.parameters(recurse=False): # raise error if uninitialized module has any parameters raise NotImplementedError( - f'{module.__class__.__name__} parameters are not initialized by param_init_fn.' - ) + f'{module.__class__.__name__} parameters are not initialized by any of the registered module_init_fns. ' + + + 'Please add an appropriate module_init_fn to the registry. Currently registered module_init_fns are: ' + + ', '.join(module_init_fns.get_all())) def _megablocks_sparse_mlp_generic_param_init_fn_( @@ -726,13 +828,18 @@ def xavier_normal_param_init_fn_( ) -MODEL_INIT_REGISTRY = { - 'default_': torch_default_param_init_fn_, - 'baseline_': baseline_param_init_fn_, - 'kaiming_uniform_': kaiming_uniform_param_init_fn_, - 'kaiming_normal_': kaiming_normal_param_init_fn_, - 'neox_init_': neox_param_init_fn_, - 'small_init_': small_param_init_fn_, - 'xavier_uniform_': xavier_uniform_param_init_fn_, - 'xavier_normal_': xavier_normal_param_init_fn_, -} +param_init_fns.register('default_', func=torch_default_param_init_fn_) +param_init_fns.register('baseline_', func=baseline_param_init_fn_) +param_init_fns.register('kaiming_uniform_', func=kaiming_uniform_param_init_fn_) +param_init_fns.register('kaiming_normal_', func=kaiming_normal_param_init_fn_) +param_init_fns.register('neox_init_', func=neox_param_init_fn_) +param_init_fns.register('small_init_', func=small_param_init_fn_) +param_init_fns.register('xavier_uniform_', func=xavier_uniform_param_init_fn_) +param_init_fns.register('xavier_normal_', func=xavier_normal_param_init_fn_) + +module_init_fns.register('fc', func=fc_init) +module_init_fns.register('embedding', func=embedding_init) +module_init_fns.register('norm', func=norm_init) +module_init_fns.register('multihead_attention', func=multihead_attention_init) +module_init_fns.register('te_layernorm_mlp', func=te_layernorm_mlp_init) +module_init_fns.register('moe', func=moe_init) diff --git a/llmfoundry/registry.py b/llmfoundry/registry.py index 424075da3b..6e1824ea08 100644 --- a/llmfoundry/registry.py +++ b/llmfoundry/registry.py @@ -12,7 +12,10 @@ from transformers import PreTrainedTokenizerBase from llmfoundry.interfaces import CallbackWithConfig -from llmfoundry.layers_registry import norms +from llmfoundry.layers_registry import (attention_classes, + attention_implementations, fcs, ffns, + ffns_with_megablocks, ffns_with_norm, + module_init_fns, norms, param_init_fns) from llmfoundry.utils.registry_utils import create_registry _loggers_description = ( @@ -85,17 +88,24 @@ entry_points=True, description=_schedulers_description) -_models_description = """The models registry is used to register classes that implement the ComposerModel interface. The model -constructor should accept two arguments: an omegaconf DictConfig named `om_model_config` and a PreTrainedTokenizerBase named `tokenizer`. -Note: This will soon be updated to take in named kwargs instead of a config directly.""" +_models_description = ( + 'The models registry is used to register classes that implement the ComposerModel interface. ' + + + 'The model constructor should accept two arguments: an omegaconf DictConfig named `om_model_config` and a PreTrainedTokenizerBase named `tokenizer`. ' + + + 'Note: This will soon be updated to take in named kwargs instead of a config directly.' +) models = create_registry('llmfoundry', 'models', generic_type=Type[ComposerModel], entry_points=True, description=_models_description) -_dataloaders_description = """The dataloaders registry is used to register functions that create a DataSpec. The function should take -a DictConfig, a PreTrainedTokenizerBase, and an int as arguments, and return a DataSpec.""" +_dataloaders_description = ( + 'The dataloaders registry is used to register functions that create a DataSpec. The function should take ' + + + 'a DictConfig, a PreTrainedTokenizerBase, and an int as arguments, and return a DataSpec.' +) dataloaders = create_registry( 'llmfoundry', 'dataloaders', @@ -103,7 +113,9 @@ entry_points=True, description=_dataloaders_description) -_metrics_description = """The metrics registry is used to register classes that implement the torchmetrics.Metric interface.""" +_metrics_description = ( + 'The metrics registry is used to register classes that implement the torchmetrics.Metric interface.' +) metrics = create_registry('llmfoundry', 'metrics', generic_type=Type[Metric], @@ -121,4 +133,12 @@ 'metrics', 'dataloaders', 'norms', + 'param_init_fns', + 'module_init_fns', + 'ffns', + 'ffns_with_norm', + 'ffns_with_megablocks', + 'attention_classes', + 'attention_implementations', + 'fcs', ] diff --git a/llmfoundry/utils/builders.py b/llmfoundry/utils/builders.py index 5cdbb4ee62..2c9c3d6ac2 100644 --- a/llmfoundry/utils/builders.py +++ b/llmfoundry/utils/builders.py @@ -6,14 +6,13 @@ import logging import os import re +import warnings from collections import OrderedDict from typing import (Any, ContextManager, Dict, Iterable, List, Optional, Tuple, Union) import torch from composer.core import Algorithm, Callback, Evaluator -from composer.datasets.in_context_learning_evaluation import \ - get_icl_task_dataloader from composer.loggers import LoggerDestination from composer.models import ComposerModel from composer.optim.scheduler import ComposerScheduler @@ -27,8 +26,11 @@ from llmfoundry import registry from llmfoundry.callbacks import EvalGauntlet from llmfoundry.data.dataloader import build_dataloader +from llmfoundry.eval.datasets.in_context_learning_evaluation import \ + get_icl_task_dataloader from llmfoundry.tokenizers.tiktoken import TiktokenTokenizerWrapper from llmfoundry.utils.registry_utils import construct_from_registry +from llmfoundry.utils.warnings import VersionedDeprecationWarning log = logging.getLogger(__name__) @@ -498,8 +500,15 @@ def _validate_cfg(icl_cfg: DictConfig): icl_cfg.metric_names = [ 'InContextLearningMultipleChoiceAccuracy' ] - elif icl_cfg.icl_task_type == 'question_answering': - icl_cfg.metric_names = ['InContextLearningQAAccuracy'] + elif icl_cfg.icl_task_type == 'generation_task_with_answers' or icl_cfg.icl_task_type == 'question_answering': + if icl_cfg.icl_task_type == 'question_answering': + warnings.warn( + VersionedDeprecationWarning( + "ICL task type 'question_answering' is now deprecated. Use identifier 'generation_task_with_answers'", + 'v0.9.0')) + icl_cfg.metric_names = [ + 'InContextLearningGenerationExactMatchAccuracy' + ] elif icl_cfg.icl_task_type == 'code_evaluation': icl_cfg.metric_names = ['InContextLearningCodeEvalAccuracy'] else: diff --git a/llmfoundry/utils/config_utils.py b/llmfoundry/utils/config_utils.py index 4bcb196f2e..16ae1aafee 100644 --- a/llmfoundry/utils/config_utils.py +++ b/llmfoundry/utils/config_utils.py @@ -11,6 +11,7 @@ from omegaconf import DictConfig, ListConfig from omegaconf import OmegaConf as om +from llmfoundry.layers_registry import ffns_with_megablocks from llmfoundry.models.utils import init_empty_weights log = logging.getLogger(__name__) @@ -141,7 +142,7 @@ def process_init_device(model_cfg: DictConfig, fsdp_config: Optional[Dict]): # Set ffn_config.device_mesh to fsdp_config.device_mesh if fsdp_config is not None and 'device_mesh' in fsdp_config and 'ffn_config' in model_cfg and model_cfg[ - 'ffn_config'].get('ffn_type', None) in {'mb_moe', 'mb_dmoe'}: + 'ffn_config'].get('ffn_type', None) in ffns_with_megablocks: # Raise ValueError if not using device mesh with MoE expert parallelism if fsdp_config['device_mesh'] is None and model_cfg['ffn_config'].get( 'moe_world_size', 1) > 1: diff --git a/llmfoundry/utils/huggingface_hub_utils.py b/llmfoundry/utils/huggingface_hub_utils.py index 5a198bc8df..3903a9bed3 100644 --- a/llmfoundry/utils/huggingface_hub_utils.py +++ b/llmfoundry/utils/huggingface_hub_utils.py @@ -132,7 +132,8 @@ def edit_files_for_hf_compatibility( flatten_imports_prefix: Sequence[str] = ('llmfoundry',), remove_imports_prefix: Sequence[str] = ('composer', 'omegaconf', 'llmfoundry.metrics', - 'llmfoundry.utils.builders'), + 'llmfoundry.eval', + 'llmfoundry.utils.builders') ) -> None: """Edit files to be compatible with Hugging Face Hub. diff --git a/llmfoundry/utils/registry_utils.py b/llmfoundry/utils/registry_utils.py index 0901ea198a..0eeefbae74 100644 --- a/llmfoundry/utils/registry_utils.py +++ b/llmfoundry/utils/registry_utils.py @@ -1,9 +1,11 @@ # Copyright 2024 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 +import copy import functools import importlib.util import os +from contextlib import contextmanager from pathlib import Path from types import ModuleType from typing import (Any, Callable, Dict, Generic, Optional, Sequence, Type, @@ -15,6 +17,7 @@ T = TypeVar('T') TypeBoundT = TypeVar('TypeBoundT', bound=Type) +CallableBoundT = TypeVar('CallableBoundT', bound=Callable[..., Any]) class TypedRegistry(catalogue.Registry, Generic[T]): @@ -142,7 +145,7 @@ def construct_from_registry( ) if post_validation_function is not None: - post_validation_function(registered_constructor) + post_validation_function(constructed_item) return constructed_item @@ -173,3 +176,13 @@ def import_file(loc: Union[str, Path]) -> ModuleType: except Exception as e: raise RuntimeError(f'Error executing {loc}') from e return module + + +@contextmanager +def save_registry(): + """Save the registry state and restore after the context manager exits.""" + saved_registry_state = copy.deepcopy(catalogue.REGISTRY) + + yield + + catalogue.REGISTRY = saved_registry_state diff --git a/scripts/data_prep/convert_finetuning_dataset.py b/scripts/data_prep/convert_finetuning_dataset.py index 594e4f778f..e78e76a912 100644 --- a/scripts/data_prep/convert_finetuning_dataset.py +++ b/scripts/data_prep/convert_finetuning_dataset.py @@ -269,6 +269,7 @@ def main(args: Namespace) -> None: examples_removed = 0 for sample in tqdm(samples, desc=split_name): formatted_sample = preprocessing_fn(sample) + assert isinstance(formatted_sample, dict) # Use the _get_example_type utility to confirm that the formatted sample # can be interpreted by the tokenization code @@ -300,13 +301,15 @@ def main(args: Namespace) -> None: out.write(sample_to_write) else: if example_type == 'prompt_response': - encoded_sample = { - key: formatted_sample[key].encode('utf-8') - for key in ['prompt', 'response'] - } + encoded_sample = {} + for key in ['prompt', 'response']: + value = formatted_sample[key] + assert isinstance(value, str) + encoded_sample[key] = value.encode('utf-8') + out.write(encoded_sample) else: - encoded_sample = formatted_sample - out.write(encoded_sample) + out.write(formatted_sample) + if tokenizer is not None and examples_removed > 0: warnings.warn( f'Dropped {examples_removed} examples where the prompt was longer than {args.max_seq_len}, ' diff --git a/scripts/data_prep/convert_text_to_mds.py b/scripts/data_prep/convert_text_to_mds.py index df39e38a90..be986fc24d 100644 --- a/scripts/data_prep/convert_text_to_mds.py +++ b/scripts/data_prep/convert_text_to_mds.py @@ -114,6 +114,13 @@ def parse_args() -> Namespace: help='If true, reprocess the input_folder to mds format. Otherwise, ' + 'only reprocess upon changes to the input folder or dataset creation parameters.', ) + parser.add_argument( + '--trust-remote-code', + type=bool, + required=False, + default=False, + help='If true, allows custom code to be executed to load the tokenizer', + ) parsed = parser.parse_args() @@ -124,7 +131,8 @@ def parse_args() -> Namespace: parser.error( 'Cannot set --eos_text with --use_tokenizer_eos. Please specify one.' ) - tokenizer = AutoTokenizer.from_pretrained(parsed.tokenizer) + tokenizer = AutoTokenizer.from_pretrained( + parsed.tokenizer, trust_remote_code=parsed.trust_remote_code) parsed.eos_text = tokenizer.eos_token # now that we have validated them, change BOS/EOS to strings @@ -171,6 +179,7 @@ def get_task_args( bos_text: str, no_wrap: bool, compression: str, + trust_remote_code: bool, ) -> Iterable: """Get download_and_convert arguments split across n_groups. @@ -187,6 +196,7 @@ def get_task_args( bos_text (str): Text to prepend to each example to separate concatenated samples no_wrap: (bool): Whether to let text examples wrap across multiple training examples compression (str): The compression algorithm to use for MDS writing + trust_remote_code (bool): If true, allows custom code to be executed to load the tokenizer """ num_objects = len(object_names) objs_per_group = math.ceil(num_objects / n_groups) @@ -202,6 +212,7 @@ def get_task_args( bos_text, no_wrap, compression, + trust_remote_code, ) @@ -223,6 +234,7 @@ def download_and_convert( bos_text: str, no_wrap: bool, compression: str, + trust_remote_code: bool, ): """Downloads and converts text fies to MDS format. @@ -236,6 +248,7 @@ def download_and_convert( bos_text (str): Text to prepend to each example to separate concatenated samples no_wrap: (bool): Whether to let text examples wrap across multiple training examples compression (str): The compression algorithm to use for MDS writing + trust_remote_code (bool): If true, allows custom code to be executed to load the tokenizer """ object_store = maybe_create_object_store_from_uri(input_folder) @@ -244,7 +257,8 @@ def download_and_convert( downloading_iter = DownloadingIterable(object_names=file_names, output_folder=tmp_dir, object_store=object_store) - tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) + tokenizer = AutoTokenizer.from_pretrained( + tokenizer_name, trust_remote_code=trust_remote_code) tokenizer.model_max_length = 5000000000 # Hack to prevent warnings from HuggingFace # Use the ConcatTokensDataset from LLM-foundry to concatenate sequences of tokens up @@ -353,6 +367,7 @@ def convert_text_to_mds( processes: int, args_str: str, reprocess: bool, + trust_remote_code: bool, ): """Convert a folder of text files to MDS format. @@ -368,6 +383,7 @@ def convert_text_to_mds( processes (int): The number of processes to use. args_str (str): String representation of the arguments reprocess (bool): Whether to always reprocess the given folder of text files + trust_remote_code (bool): If true, allows custom code to be executed to load the tokenizer """ is_remote_output = is_remote_path(output_folder) @@ -396,7 +412,7 @@ def convert_text_to_mds( # Download and convert the text files in parallel args = get_task_args(object_names, local_output_folder, input_folder, processes, tokenizer_name, concat_tokens, eos_text, - bos_text, no_wrap, compression) + bos_text, no_wrap, compression, trust_remote_code) with ProcessPoolExecutor(max_workers=processes) as executor: list(executor.map(download_and_convert_starargs, args)) @@ -405,7 +421,7 @@ def convert_text_to_mds( else: download_and_convert(object_names, local_output_folder, input_folder, tokenizer_name, concat_tokens, eos_text, bos_text, - no_wrap, compression) + no_wrap, compression, trust_remote_code) # Write a done file with the args and object names write_done_file(local_output_folder, args_str, object_names) @@ -462,6 +478,7 @@ def _args_str(original_args: Namespace) -> str: compression=args.compression, processes=args.processes, reprocess=args.reprocess, + trust_remote_code=args.trust_remote_code, args_str=_args_str(args)) except Exception as e: if mosaicml_logger is not None: diff --git a/scripts/eval/README.md b/scripts/eval/README.md index 1e6547facf..ba2e9e2c79 100644 --- a/scripts/eval/README.md +++ b/scripts/eval/README.md @@ -145,7 +145,8 @@ You can use the default `icl_tasks` and `eval_gauntlet` configs or specify your ICL evaluation measures a model’s ability to solve novel problems by being provided examples in-context without ever being specifically trained to answer such questions. -Composer supports a number of different standard ICL formats and allows users to upload their own datasets that correspond to those formats. +We support a number of standard ICL formats and allow users to upload their own datasets that correspond to these formats. All of our ICL task types are implemented in `llm-foundry/llmfoundry/eval/datasets/in_context_learning_evaluation.py` while all of our ICL +metrics are implemented in `llm-foundry/llmfoundry/eval/metrics/nlp.py`. You can see which metrics work with which task types in the `llmfoundry.utils.builders.build_icl_evaluators` helper function. This document explains the ICL formats compatible with [Composer](https://github.com/mosaicml/composer), summarizes how to add new datasets in those formats, and catalogs the datasets currently used by the research team to evaluate models. @@ -153,19 +154,19 @@ This document explains the ICL formats compatible with [Composer](https://github ## Supported ICL formats -Composer currently supports five ICL formats: +llm-foundry currently supports five ICL formats: -1. [InContextLearningQATaskDataset](https://github.com/mosaicml/composer/blob/336bf8db3e2c09ae942d4bf8a819935106589d1a/composer/datasets/in_context_learning_evaluation.py#L103) -2. [InContextLearningLMTaskDataset](https://github.com/mosaicml/composer/blob/336bf8db3e2c09ae942d4bf8a819935106589d1a/composer/datasets/in_context_learning_evaluation.py#L293) -3. [InContextLearningMultipleChoiceTaskDataset](https://github.com/mosaicml/composer/blob/336bf8db3e2c09ae942d4bf8a819935106589d1a/composer/datasets/in_context_learning_evaluation.py#L444) -4. [InContextLearningSchemaTaskDataset](https://github.com/mosaicml/composer/blob/336bf8db3e2c09ae942d4bf8a819935106589d1a/composer/datasets/in_context_learning_evaluation.py#L676) -5. [InContextLearningCodeEvalDataset](https://github.com/mosaicml/composer/blob/336bf8db3e2c09ae942d4bf8a819935106589d1a/composer/datasets/in_context_learning_evaluation.py#L852) +1. InContextLearningGenerationTaskWithAnswersDataset +2. InContextLearningLMTaskDataset +3. InContextLearningMultipleChoiceTaskDataset +4. InContextLearningSchemaTaskDataset +5. InContextLearningCodeEvalDataset ---- -### InContextLearningQATaskDataset +### InContextLearningGenerationTaskWithAnswersDataset -The ICL question answering (QA) task supports free response question answering evaluation using the model’s generate function. A QA dataset consists of a list of JSONs containing a question (under the key `context`), a correct answer (under the key `answer`), and a list of alternative spellings of the answer that would be considered permissible (under the key `aliases`). The QA task works with the NLP metric: [InContextLearningQAAccuracy](https://docs.mosaicml.com/projects/composer/en/latest/api_reference/generated/composer.metrics.InContextLearningQAAccuracy.html) which assigns a model's output to be "correct" if, conditioned on the context, the model's generate method produces a string that is a normalized prefix for either the `answer` or any of the `aliases`. +The ICL generation with answers task supports free response generation evaluation using the model’s generate function. A generation dataset consists of a list of JSONs containing a prompt (under the key `context`), a correct answer (under the key `answer`), and a list of alternative answers that would be considered permissible (under the key `aliases`). The generation task works with the NLP metric: InContextLearningGenerationExactMatchAccuracy which assigns a model's output to be "correct" if, conditioned on the context, the model's generate method produces a string that is a normalized prefix for either the `answer` or any of the `aliases`. Required keys for each datum: * `context`: str @@ -178,7 +179,7 @@ An example datum is below: {"context": "What star sign is Jamie Lee Curtis?", "answer": "Scorpio", "aliases": ["Scorpio", "Skorpio"]} ``` -The QA task expects a **prompt string**, a **continuation delimiter** to separate questions from answers, an **example delimiter** to separate few shot examples from one another, and a **question prelimiter** to put before each question. If using the following settings, with 2 examples in context, the above datum may be rendered to the model as: +The generation task expects a **prompt string**, a **continuation delimiter** to separate questions from answers, an **example delimiter** to separate few shot examples from one another, and a **question prelimiter** to put before each question. If using the following settings, with 2 examples in context, the above datum may be rendered to the model as: ```jsx prompt_string: "Answer the following trivia question:\n", example_delimiter: "\n", continuation_delimiter: " Answer: ", question_prelimiter: "Question: " @@ -203,9 +204,9 @@ Below is a complete YAML section that works with the TriviaQA dataset in [`scrip - 5 - 10 batch_size: 4 - icl_task_type: question_answering + icl_task_type: generation_task_with_answers metric_names: - - InContextLearningQAAccuracy + - InContextLearningGenerationExactMatchAccuracy prompt_string: '' # this goes at the beginning of each input example_delimiter: "\n" # this goes between fewshot examples continuation_delimiter: ' ' # this separates questions from answers @@ -215,7 +216,7 @@ Below is a complete YAML section that works with the TriviaQA dataset in [`scrip ### InContextLearningLMTaskDataset -The ICL language modeling (LM) task assesses the model’s ability to predict a precise sequence of tokens (called a continuation) following some context using the model’s `forward` function. An LM dataset consists of a list of JSONs containing a context (under the key `context`) and a continuation (under the key `continuation` that the model must correctly predict conditioned on the context. The LM task uses the NLP metric [InContextLearningLMAccuracy](https://docs.mosaicml.com/projects/composer/en/latest/api_reference/generated/composer.metrics.InContextLearningLMAccuracy.html), which assigns a model's output to be "correct" if, conditioned on the context tokens, the model's argmax output logits exactly match the tokens in the continuation. +The ICL language modeling (LM) task assesses the model’s ability to predict a precise sequence of tokens (called a continuation) following some context using the model’s `forward` function. An LM dataset consists of a list of JSONs containing a context (under the key `context`) and a continuation (under the key `continuation` that the model must correctly predict conditioned on the context. The LM task uses the NLP metric InContextLearningLMAccuracy, which assigns a model's output to be "correct" if, conditioned on the context tokens, the model's argmax output logits exactly match the tokens in the continuation. Required keys for each datum: * `context`: str @@ -256,7 +257,7 @@ Below is a YAML section that works with the Lambada OpenAI dataset in [`scripts/ ### InContextLearningMultipleChoiceTaskDataset -The ICL multiple choice (MC) task assesses the model’s ability to answer multiple choice questions by assigning highest per token probability to the correct answer. An MC dataset consists of a list of JSONs containing a query (under the key `query`), a list of choices (under the key `choices`), and the index indicating the correct answer (under the key `gold`). The MC task works with the NLP metric [InContextLearningMultipleChoiceAccuracy](https://docs.mosaicml.com/projects/composer/en/latest/api_reference/generated/composer.metrics.InContextLearningMultipleChoiceAccuracy.html), which separately runs the model's `forward()` method on the query prepended to each choice, and then determines the model to be correct if the correct choice has the lowest per token perplexity conditioned on the query. +The ICL multiple choice (MC) task assesses the model’s ability to answer multiple choice questions by assigning highest per token probability to the correct answer. An MC dataset consists of a list of JSONs containing a query (under the key `query`), a list of choices (under the key `choices`), and the index indicating the correct answer (under the key `gold`). The MC task works with the NLP metric InContextLearningMultipleChoiceAccuracy, which separately runs the model's `forward()` method on the query prepended to each choice, and then determines the model to be correct if the correct choice has the lowest per token perplexity conditioned on the query. Required keys for each datum: * `query`: str @@ -294,7 +295,6 @@ Below is a YAML section that works with the HellaSwag dataset in [`scripts/eval/ icl_task_type: multiple_choice metric_names: - InContextLearningMultipleChoiceAccuracy - - InContextLearningMCExpectedCalibrationError prompt_string: '' # this goes at the beginning of each input example_delimiter: "\n" # this goes between fewshot examples continuation_delimiter: ' ' # this separates questions from answers @@ -306,7 +306,7 @@ Below is a YAML section that works with the HellaSwag dataset in [`scripts/eval/ The ICL schema task assesses the model’s ability to determine which of some set of possible contexts (under the key `context_options`) makes a sequence of tokens (under the key `continuation`) most likely, with the correct context indicated by "gold". This task is based on [A Simple Method for Commonsense Reasoning](https://arxiv.org/abs/1806.02847). -The schema task works with the NLP metric [InContextLearningMultipleChoiceAccuracy](https://docs.mosaicml.com/projects/composer/en/latest/api_reference/generated/composer.metrics.InContextLearningMultipleChoiceAccuracy.html), which separately runs the model's `forward()` method on each context option prepended to the continuation and rates the model correct if it assigns minimum per token perplexity to the continuation conditioned on the true context. +The schema task works with the NLP metric InContextLearningMultipleChoiceAccuracy, which separately runs the model's `forward()` method on each context option prepended to the continuation and rates the model correct if it assigns minimum per token perplexity to the continuation conditioned on the true context. Required keys for each datum: * query: str diff --git a/scripts/eval/yamls/tasks_v0.1.yaml b/scripts/eval/yamls/tasks_v0.1.yaml index 44f031ae3a..6546b13dd7 100644 --- a/scripts/eval/yamls/tasks_v0.1.yaml +++ b/scripts/eval/yamls/tasks_v0.1.yaml @@ -10,12 +10,12 @@ icl_tasks: label: triviaqa_sm_sub dataset_uri: eval/local_data/world_knowledge/triviaqa_sm_sub.jsonl num_fewshot: [3] - icl_task_type: question_answering + icl_task_type: generation_task_with_answers - label: gsm8k dataset_uri: eval/local_data/symbolic_problem_solving/gsm8k.jsonl num_fewshot: [3] - icl_task_type: question_answering + icl_task_type: generation_task_with_answers cot_delimiter: " #### " continuation_delimiter: "\nA: Let's think step by step. " question_prelimiter: "Q: " @@ -23,21 +23,21 @@ icl_tasks: label: agi_eval_sat_math dataset_uri: eval/local_data/symbolic_problem_solving/agi_eval_sat_math.jsonl num_fewshot: [3] - icl_task_type: question_answering + icl_task_type: generation_task_with_answers cot_delimiter: " #### " continuation_delimiter: "\nA: Let's think step by step. " - label: aqua dataset_uri: eval/local_data/symbolic_problem_solving/aqua.jsonl num_fewshot: [3] - icl_task_type: question_answering + icl_task_type: generation_task_with_answers cot_delimiter: " #### " continuation_delimiter: "\nA: Let's think step by step. " - label: svamp dataset_uri: eval/local_data/symbolic_problem_solving/svamp.jsonl num_fewshot: [3] - icl_task_type: question_answering + icl_task_type: generation_task_with_answers continuation_delimiter: "\nUsing the formula below:\n" cot_delimiter: " #### " question_prelimiter: "Q: " diff --git a/scripts/eval/yamls/tasks_v0.2.yaml b/scripts/eval/yamls/tasks_v0.2.yaml index e23b4df1a5..ae39d87fbd 100644 --- a/scripts/eval/yamls/tasks_v0.2.yaml +++ b/scripts/eval/yamls/tasks_v0.2.yaml @@ -10,12 +10,12 @@ icl_tasks: label: triviaqa_sm_sub dataset_uri: eval/local_data/world_knowledge/triviaqa_sm_sub.jsonl num_fewshot: [3] - icl_task_type: question_answering + icl_task_type: generation_task_with_answers - label: gsm8k dataset_uri: eval/local_data/symbolic_problem_solving/gsm8k.jsonl num_fewshot: [8, 5] - icl_task_type: question_answering + icl_task_type: generation_task_with_answers cot_delimiter: " #### " continuation_delimiter: "\nA: Let's think step by step. " question_prelimiter: "Q: " @@ -23,21 +23,21 @@ icl_tasks: label: agi_eval_sat_math dataset_uri: eval/local_data/symbolic_problem_solving/agi_eval_sat_math.jsonl num_fewshot: [3] - icl_task_type: question_answering + icl_task_type: generation_task_with_answers cot_delimiter: " #### " continuation_delimiter: "\nA: Let's think step by step. " - label: aqua dataset_uri: eval/local_data/symbolic_problem_solving/aqua.jsonl num_fewshot: [3] - icl_task_type: question_answering + icl_task_type: generation_task_with_answers cot_delimiter: " #### " continuation_delimiter: "\nA: Let's think step by step. " - label: svamp dataset_uri: eval/local_data/symbolic_problem_solving/svamp.jsonl num_fewshot: [5] - icl_task_type: question_answering + icl_task_type: generation_task_with_answers continuation_delimiter: "\nUsing the formula below:\n" cot_delimiter: " #### " question_prelimiter: "Q: " diff --git a/scripts/eval/yamls/tasks_v0.3.yaml b/scripts/eval/yamls/tasks_v0.3.yaml index e02178710e..396ceaaf85 100644 --- a/scripts/eval/yamls/tasks_v0.3.yaml +++ b/scripts/eval/yamls/tasks_v0.3.yaml @@ -3,7 +3,7 @@ icl_tasks: label: gsm8k dataset_uri: eval/local_data/symbolic_problem_solving/gsm8k_prepended_8shot.jsonl num_fewshot: [0] - icl_task_type: question_answering + icl_task_type: generation_task_with_answers cot_delimiter: "The answer is " continuation_delimiter: "\n\nA:" question_prelimiter: "" @@ -15,13 +15,13 @@ icl_tasks: label: triviaqa_sm_sub dataset_uri: eval/local_data/world_knowledge/triviaqa_sm_sub.jsonl num_fewshot: [3] - icl_task_type: question_answering + icl_task_type: generation_task_with_answers do_normalization: true - label: svamp dataset_uri: eval/local_data/symbolic_problem_solving/svamp.jsonl num_fewshot: [5] - icl_task_type: question_answering + icl_task_type: generation_task_with_answers cot_delimiter: "The answer is " continuation_delimiter: "\n\nA:" question_prelimiter: "Question: " diff --git a/scripts/train/train.py b/scripts/train/train.py index d256796a3d..0fcf77f78b 100644 --- a/scripts/train/train.py +++ b/scripts/train/train.py @@ -13,7 +13,6 @@ import torch from composer import Trainer from composer.core.callback import Callback -from composer.metrics.nlp import InContextLearningMetric from composer.profiler import (JSONTraceHandler, Profiler, TraceHandler, cyclic_schedule) from composer.utils import dist, get_device, reproducibility @@ -22,6 +21,7 @@ from omegaconf import OmegaConf as om from rich.traceback import install +from llmfoundry.eval.metrics.nlp import InContextLearningMetric from llmfoundry.utils import (find_mosaicml_logger, log_train_analytics, maybe_create_mosaicml_logger) @@ -30,6 +30,7 @@ from llmfoundry.callbacks import AsyncEval from llmfoundry.data.dataloader import build_dataloader +from llmfoundry.layers_registry import ffns_with_megablocks from llmfoundry.utils.builders import (add_metrics_to_eval_loaders, build_algorithm, build_callback, build_composer_model, build_evaluators, @@ -156,14 +157,14 @@ def validate_config(cfg: TrainConfig): act_ckpt = fsdp_config.get('activation_checkpointing', False) if fsdp_config else False act_ckpt_reentrant = fsdp_config.get( - 'activation_checkpointing_reentrant', True) if fsdp_config else True - if fsdp_config is not None and act_ckpt == True and act_ckpt_reentrant == False and cfg.fsdp_config is not None: + 'activation_checkpointing_reentrant', False) + if fsdp_config is not None and act_ckpt == True and act_ckpt_reentrant == True: warnings.warn( '`te.Linear` layers do not support activation_checkpointing with ' - + '`activation_checkpointing_reentrant = False`. ' + - 'Setting cfg.fsdp_config.activation_checkpointing_reentrant=True.' + + '`activation_checkpointing_reentrant = True`. ' + + 'Setting cfg.fsdp_config.activation_checkpointing_reentrant=False.' ) - cfg.fsdp_config['activation_checkpointing_reentrant'] = True + cfg.fsdp_config.activation_checkpointing_reentrant = False if cfg.model.get('ffn_config', {}).get('ffn_type', 'mptmlp') == 'te_ln_mlp': warnings.warn( @@ -178,7 +179,7 @@ def validate_config(cfg: TrainConfig): ) if cfg.model.get('ffn_config', {}).get('ffn_type', - 'mptmlp') in ('mb_moe', 'mb_dmoe'): + 'mptmlp') in ffns_with_megablocks: moe_world_size = cfg.model.get('ffn_config', {}).get('moe_world_size', 1) use_orig_params = cfg.fsdp_config.get( diff --git a/scripts/train/yamls/finetune/dbrx-full-ft.yaml b/scripts/train/yamls/finetune/dbrx-full-ft.yaml index a0e2504787..9cb53e40fd 100644 --- a/scripts/train/yamls/finetune/dbrx-full-ft.yaml +++ b/scripts/train/yamls/finetune/dbrx-full-ft.yaml @@ -86,7 +86,6 @@ seed: 17 device_train_microbatch_size: 1 device_eval_batch_size: 1 precision: amp_bf16 -autoresume: true dist_timeout: 3600 expandable_segments: true diff --git a/scripts/train/yamls/finetune/dbrx-lora-ft.yaml b/scripts/train/yamls/finetune/dbrx-lora-ft.yaml index 7fb921ae16..06e8f1d6f0 100644 --- a/scripts/train/yamls/finetune/dbrx-lora-ft.yaml +++ b/scripts/train/yamls/finetune/dbrx-lora-ft.yaml @@ -94,7 +94,6 @@ seed: 17 device_train_microbatch_size: 1 device_eval_batch_size: 1 precision: amp_bf16 -autoresume: true dist_timeout: 3600 expandable_segments: true diff --git a/tests/a_scripts/data_prep/test_convert_text_to_mds.py b/tests/a_scripts/data_prep/test_convert_text_to_mds.py index e458cb1dfc..bd96de695c 100644 --- a/tests/a_scripts/data_prep/test_convert_text_to_mds.py +++ b/tests/a_scripts/data_prep/test_convert_text_to_mds.py @@ -106,6 +106,7 @@ def call_convert_text_to_mds() -> None: processes=processes, args_str='Namespace()', reprocess=False, + trust_remote_code=False, ) call_convert_text_to_mds() @@ -195,6 +196,7 @@ def call_convert_text_to_mds(reprocess: bool): processes=1, args_str='Namespace()', reprocess=reprocess, + trust_remote_code=False, ) # Create input text data @@ -234,6 +236,7 @@ def test_input_folder_not_exist(tmp_path: pathlib.Path): processes=1, args_str='Namespace()', reprocess=False, + trust_remote_code=False, ) diff --git a/tests/a_scripts/train/test_train.py b/tests/a_scripts/train/test_train.py index 3878b22704..7899eeda0a 100644 --- a/tests/a_scripts/train/test_train.py +++ b/tests/a_scripts/train/test_train.py @@ -148,6 +148,7 @@ def test_train_multi_eval(tmp_path: pathlib.Path): tuple) +@pytest.mark.gpu def test_validate_config(): conf_path: str = os.path.join( REPO_DIR, diff --git a/tests/callbacks/test_eval_gauntlet_callback.py b/tests/callbacks/test_eval_gauntlet_callback.py index 3a1e371ab8..8d9938e3a1 100644 --- a/tests/callbacks/test_eval_gauntlet_callback.py +++ b/tests/callbacks/test_eval_gauntlet_callback.py @@ -9,9 +9,9 @@ import torch from composer.core import State from composer.loggers import InMemoryLogger, Logger -from composer.metrics import InContextLearningLMAccuracy from transformers import AutoTokenizer +from llmfoundry.eval.metrics.nlp import InContextLearningLMAccuracy from llmfoundry.utils.builders import build_icl_data_and_gauntlet diff --git a/tests/data/test_dataloader.py b/tests/data/test_dataloader.py index f5c2631fa7..c99ae6baf2 100644 --- a/tests/data/test_dataloader.py +++ b/tests/data/test_dataloader.py @@ -1093,3 +1093,87 @@ def test_build_unknown_dataloader(): tokenizer = MagicMock() with pytest.raises(catalogue.RegistryError): _ = build_dataloader(cfg, tokenizer, 2) + + +invalid_conversation_params_sharegpt = [ + 'add_invalid_last_chat_message', 'add_invalid_content_type', + 'add_invalid_role', 'add_not_alternating_roles' +] + + +@pytest.mark.parametrize( + ','.join(invalid_conversation_params_sharegpt), + generate_exclusive_test_params(invalid_conversation_params_sharegpt)) +def test_sharegpt_format(tmp_path: pathlib.Path, + add_invalid_last_chat_message: bool, + add_invalid_content_type: bool, add_invalid_role: bool, + add_not_alternating_roles: bool): + tokenizer_name = 'mosaicml/mpt-7b' + max_seq_len = 2048 + dataset_size = 5 + device_batch_size = 5 + tiny_dataset_folder_path = tmp_path + tiny_dataset_path = str(tiny_dataset_folder_path / 'train.jsonl') + + tokenizer = build_tokenizer( + tokenizer_name=tokenizer_name, + tokenizer_kwargs={'model_max_length': max_seq_len}, + ) + tokenizer.add_special_tokens({ + 'pad_token': '', + 'bos_token': '', + 'eos_token': '', + }) + + if dist.get_global_rank() == 0: + make_tiny_conversation_ft_dataset( + path=tiny_dataset_path, + size=dataset_size, + add_invalid_last_chat_message=add_invalid_last_chat_message, + add_invalid_message_key_quantity=False, + add_invalid_content_type=add_invalid_content_type, + add_invalid_role=add_invalid_role, + add_not_alternating_roles=add_not_alternating_roles, + use_messages_format=False, + ) + + cfg = { + 'name': 'finetuning', + 'dataset': { + 'hf_name': str(tiny_dataset_folder_path), + 'preprocessing_fn': 'teknium/OpenHermes-2.5', + 'split': 'train', + 'max_seq_len': max_seq_len, + 'decoder_only_format': True, + 'allow_pad_trimming': False, + 'packing_ratio': None, + 'shuffle': True, + }, + 'drop_last': False, + 'num_workers': 0, + 'prefetch_factor': None, + 'pin_memory': False, + 'persistent_workers': False, + 'timeout': 0 + } + + cfg = om.create(cfg) + + error_context = contextlib.nullcontext() + if add_invalid_last_chat_message: + error_context = pytest.raises(InvalidLastChatMessageRoleError, + match='Invalid last message role:') + if add_invalid_content_type: + error_context = pytest.raises(InvalidContentTypeError, + match='Expected content to be') + if add_invalid_role: + error_context = pytest.raises(InvalidRoleError, + match='Expected role to be one of') + + if add_not_alternating_roles: + error_context = pytest.raises(ConsecutiveRepeatedChatRolesError, + match='Conversation roles must alternate') + + with error_context: + build_finetuning_dataloader(cfg, tokenizer, + device_batch_size).dataloader diff --git a/tests/data/test_tasks.yaml b/tests/data/test_tasks.yaml index cec7984320..cf02ffcbbb 100644 --- a/tests/data/test_tasks.yaml +++ b/tests/data/test_tasks.yaml @@ -20,4 +20,4 @@ icl_tasks: label: triviaqa dataset_uri: scripts/eval/local_data/world_knowledge/triviaqa_small.jsonl # ADD YOUR OWN DATASET URI num_fewshot: [0, 1] - icl_task_type: question_answering + icl_task_type: generation_task_with_answers diff --git a/tests/data/test_template_tokenization.py b/tests/data/test_template_tokenization.py index a45c4d8f0d..632a79dac9 100644 --- a/tests/data/test_template_tokenization.py +++ b/tests/data/test_template_tokenization.py @@ -9,6 +9,7 @@ from llmfoundry.data.finetuning.tasks import (_ALLOWED_PROMPT_KEYS, _ALLOWED_RESPONSE_KEYS, _slice_chat_formatted_example, + dataset_constructor, tokenize_formatted_example) from llmfoundry.utils.builders import build_tokenizer @@ -178,34 +179,67 @@ def test_tokenize_instruct_example_well_formed(): @pytest.mark.parametrize( 'tokenizer_name', ['EleutherAI/gpt-neox-20b', 'HuggingFaceH4/zephyr-7b-beta', 't5-base']) -def test_multi_turn_chat_slicing(tokenizer_name: str): - convo = [ - { - 'role': 'system', - 'content': 'everyone thinks you are so cool' - }, - { - 'role': 'user', - 'content': 'hiiii' - }, - { - 'role': 'assistant', - 'content': 'yassss' - }, - { - 'role': 'user', - 'content': 'HIIIIII!!!' - }, - { - 'role': 'assistant', - 'content': 'YASSSSSS' - }, - ] +@pytest.mark.parametrize('messages_format', [True, False]) +def test_multi_turn_chat_slicing(tokenizer_name: str, messages_format: bool): + if messages_format: + convo = [ + { + 'role': 'system', + 'content': 'everyone thinks you are so cool' + }, + { + 'role': 'user', + 'content': 'hiiii' + }, + { + 'role': 'assistant', + 'content': 'yassss' + }, + { + 'role': 'user', + 'content': 'HIIIIII!!!' + }, + { + 'role': 'assistant', + 'content': 'YASSSSSS' + }, + ] + else: + convo = [ + { + 'from': 'system', + 'value': 'everyone thinks you are so cool' + }, + { + 'from': 'human', + 'value': 'hiiii' + }, + { + 'from': 'gpt', + 'value': 'yassss' + }, + { + 'from': 'tool', + 'value': 'HIIIIII!!!' + }, + { + 'from': 'gpt', + 'value': 'YASSSSSS' + }, + ] + tmp = {'conversations': convo} + preprocessor = dataset_constructor.get_preprocessing_fn_from_str( + 'teknium/OpenHermes-2.5') + assert preprocessor is not None + convo = preprocessor(tmp)['messages'] + assert isinstance(convo, list) + + example = {'messages': convo} tok = transformers.AutoTokenizer.from_pretrained(tokenizer_name) templated_prompt_response_turns = _slice_chat_formatted_example( - {'messages': convo}, tok) + example, tok) reconstructed_chat = '' for prompt, response in templated_prompt_response_turns: diff --git a/tests/data_utils.py b/tests/data_utils.py index 3c077b5e71..fd24d4cbbf 100644 --- a/tests/data_utils.py +++ b/tests/data_utils.py @@ -6,7 +6,7 @@ import shutil from argparse import Namespace from pathlib import Path -from typing import Optional +from typing import Dict, List, Optional from omegaconf import DictConfig from omegaconf import OmegaConf as om @@ -99,6 +99,7 @@ def make_tiny_conversation_ft_dataset( add_invalid_content_type: bool = False, add_invalid_role: bool = False, add_not_alternating_roles: bool = False, + use_messages_format: bool = True, ): if Path(path).suffix != '.jsonl': raise ValueError(f'Path {path} must be a jsonl file.') @@ -198,6 +199,24 @@ def make_tiny_conversation_ft_dataset( }] }) + def messages_to_conversation(sample: Dict): + assert 'messages' in sample + messages = sample['messages'] + + role_map = { + 'user': 'human', + 'assistant': 'gpt', + } + conversations: List[Dict[str, str]] = [] + for message in messages: + role: str = role_map.get(message['role'], message['role']) + content: str = message['content'] + conversations.append({'from': role, 'value': content}) + return {'conversations': conversations} + + if not use_messages_format: + samples = [messages_to_conversation(sample) for sample in samples] + os.makedirs(os.path.dirname(path), exist_ok=True) with open(path, 'w') as _f: for sample in samples: diff --git a/tests/eval/local_data/gsm8k_small.jsonl b/tests/eval/local_data/gsm8k_small.jsonl new file mode 100644 index 0000000000..522966c902 --- /dev/null +++ b/tests/eval/local_data/gsm8k_small.jsonl @@ -0,0 +1,4 @@ +{"context": "Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?", "chain_of_thought": "Janet sells 16 - 3 - 4 = <<16-3-4=9>>9 duck eggs a day.\nShe makes 9 * 2 = $<<9*2=18>>18 every day at the farmer’s market.", "answer": "18"} +{"context": "A robe takes 2 bolts of blue fiber and half that much white fiber. How many bolts in total does it take?", "chain_of_thought": "It takes 2/2=<<2/2=1>>1 bolt of white fiber\nSo the total amount of fabric is 2+1=<<2+1=3>>3 bolts of fabric", "answer": "3"} +{"context": "Josh decides to try flipping a house. He buys a house for $80,000 and then puts in $50,000 in repairs. This increased the value of the house by 150%. How much profit did he make?", "chain_of_thought": "The cost of the house and repairs came out to 80,000+50,000=$<<80000+50000=130000>>130,000\nHe increased the value of the house by 80,000*1.5=<<80000*1.5=120000>>120,000\nSo the new value of the house is 120,000+80,000=$<<120000+80000=200000>>200,000\nSo he made a profit of 200,000-130,000=$<<200000-130000=70000>>70,000", "answer": "70000"} +{"context": "James decides to run 3 sprints 3 times a week. He runs 60 meters each sprint. How many total meters does he run a week?", "chain_of_thought": "He sprints 3*3=<<3*3=9>>9 times\nSo he runs 9*60=<<9*60=540>>540 meters", "answer": "540"} diff --git a/tests/eval/local_data/hellaswag_small.jsonl b/tests/eval/local_data/hellaswag_small.jsonl new file mode 100644 index 0000000000..d2e37771c9 --- /dev/null +++ b/tests/eval/local_data/hellaswag_small.jsonl @@ -0,0 +1,4 @@ +{"query": "Removing ice from car: Then, the man writes over the snow covering the window of a car, and a woman wearing winter clothes smiles. Then", "choices": [", the man adds wax to the windshield and cuts it.", ", a person board a ski lift, while two men supporting the head of the person wearing winter clothes snow as the we girls sled.", ", the man puts on a christmas coat, knitted with netting.", ", the man continues removing the snow on his car."], "gold": 3} +{"query": "Baking cookies: A female chef in white uniform shows a stack of baking pans in a large kitchen presenting them. The pans", "choices": ["contain egg yolks and baking soda.", "are then sprinkled with brown sugar.", "are placed in a strainer on the counter.", "are filled with pastries and loaded into the oven."], "gold": 3} +{"query": "Baking cookies: A female chef in white uniform shows a stack of baking pans in a large kitchen presenting them. The pans are filled with pastries and loaded into the oven. A knife", "choices": ["is seen moving on a board and cutting out its contents.", "hits the peeled cheesecake, followed by sliced custard and still cooked ice cream.", "etches a shape into the inside of the baked pans.", "is used to cut cylinder shaped dough into rounds."], "gold": 3} +{"query": "Baking cookies: A tray of potatoes is loaded into the oven and removed. A large tray of cake is flipped over and placed on counter. A large tray of meat", "choices": ["is placed onto a baked potato.", ", ls, and pickles are placed in the oven.", "is poured into a midden.", "is prepared then it is removed from the oven by a helper when done."], "gold": 3} diff --git a/tests/eval/local_data/human_eval_small.jsonl b/tests/eval/local_data/human_eval_small.jsonl new file mode 100644 index 0000000000..850d46e031 --- /dev/null +++ b/tests/eval/local_data/human_eval_small.jsonl @@ -0,0 +1,4 @@ +{"task_id": "HumanEval/0", "prompt": "from typing import List\n\n\ndef has_close_elements(numbers: List[float], threshold: float) -> bool:\n \"\"\" Check if in given list of numbers, are any two numbers closer to each other than\n given threshold.\n >>> has_close_elements([1.0, 2.0, 3.0], 0.5)\n False\n >>> has_close_elements([1.0, 2.8, 3.0, 4.0, 5.0, 2.0], 0.3)\n True\n \"\"\"\n", "entry_point": "has_close_elements", "canonical_solution": " for idx, elem in enumerate(numbers):\n for idx2, elem2 in enumerate(numbers):\n if idx != idx2:\n distance = abs(elem - elem2)\n if distance < threshold:\n return True\n\n return False\n", "test": "\n\nMETADATA = {\n 'author': 'jt',\n 'dataset': 'test'\n}\n\n\ndef check(candidate):\n assert candidate([1.0, 2.0, 3.9, 4.0, 5.0, 2.2], 0.3) == True\n assert candidate([1.0, 2.0, 3.9, 4.0, 5.0, 2.2], 0.05) == False\n assert candidate([1.0, 2.0, 5.9, 4.0, 5.0], 0.95) == True\n assert candidate([1.0, 2.0, 5.9, 4.0, 5.0], 0.8) == False\n assert candidate([1.0, 2.0, 3.0, 4.0, 5.0, 2.0], 0.1) == True\n assert candidate([1.1, 2.2, 3.1, 4.1, 5.1], 1.0) == True\n assert candidate([1.1, 2.2, 3.1, 4.1, 5.1], 0.5) == False\n\n", "test_inputs": ["([1.0, 2.0, 3.9, 4.0, 5.0, 2.2], 0.3)", "([1.0, 2.0, 3.9, 4.0, 5.0, 2.2], 0.05)", "([1.0, 2.0, 5.9, 4.0, 5.0], 0.95)", "([1.0, 2.0, 5.9, 4.0, 5.0], 0.8)", "([1.0, 2.0, 3.0, 4.0, 5.0, 2.0], 0.1)", "([1.1, 2.2, 3.1, 4.1, 5.1], 1.0)", "([1.1, 2.2, 3.1, 4.1, 5.1], 0.5)"], "test_outputs": ["True", "False", "True", "False", "True", "True", "False"], "language": "python"} +{"task_id": "HumanEval/1", "prompt": "from typing import List\n\n\ndef separate_paren_groups(paren_string: str) -> List[str]:\n \"\"\" Input to this function is a string containing multiple groups of nested parentheses. Your goal is to\n separate those group into separate strings and return the list of those.\n Separate groups are balanced (each open brace is properly closed) and not nested within each other\n Ignore any spaces in the input string.\n >>> separate_paren_groups('( ) (( )) (( )( ))')\n ['()', '(())', '(()())']\n \"\"\"\n", "entry_point": "separate_paren_groups", "canonical_solution": " result = []\n current_string = []\n current_depth = 0\n\n for c in paren_string:\n if c == '(':\n current_depth += 1\n current_string.append(c)\n elif c == ')':\n current_depth -= 1\n current_string.append(c)\n\n if current_depth == 0:\n result.append(''.join(current_string))\n current_string.clear()\n\n return result\n", "test": "\n\nMETADATA = {\n 'author': 'jt',\n 'dataset': 'test'\n}\n\n\ndef check(candidate):\n assert candidate('(()()) ((())) () ((())()())') == [\n '(()())', '((()))', '()', '((())()())'\n ]\n assert candidate('() (()) ((())) (((())))') == [\n '()', '(())', '((()))', '(((())))'\n ]\n assert candidate('(()(())((())))') == [\n '(()(())((())))'\n ]\n assert candidate('( ) (( )) (( )( ))') == ['()', '(())', '(()())']\n", "test_inputs": ["('(()()) ((())) () ((())()())',)", "('() (()) ((())) (((())))',)", "('(()(())((())))',)", "('( ) (( )) (( )( ))',)"], "test_outputs": ["['(()())', '((()))', '()', '((())()())']", "['()', '(())', '((()))', '(((())))']", "['(()(())((())))']", "['()', '(())', '(()())']"], "language": "python"} +{"task_id": "HumanEval/2", "prompt": "\n\ndef truncate_number(number: float) -> float:\n \"\"\" Given a positive floating point number, it can be decomposed into\n and integer part (largest integer smaller than given number) and decimals\n (leftover part always smaller than 1).\n\n Return the decimal part of the number.\n >>> truncate_number(3.5)\n 0.5\n \"\"\"\n", "entry_point": "truncate_number", "canonical_solution": " return number % 1.0\n", "test": "\n\nMETADATA = {\n 'author': 'jt',\n 'dataset': 'test'\n}\n\n\ndef check(candidate):\n assert candidate(3.5) == 0.5\n assert abs(candidate(1.33) - 0.33) < 1e-6\n assert abs(candidate(123.456) - 0.456) < 1e-6\n", "test_inputs": ["(3.5,)", "(1.33,)", "(123.456,)"], "test_outputs": ["0.5", "0.33000000000000007", "0.45600000000000307"], "language": "python"} +{"task_id": "HumanEval/3", "prompt": "from typing import List\n\n\ndef below_zero(operations: List[int]) -> bool:\n \"\"\" You're given a list of deposit and withdrawal operations on a bank account that starts with\n zero balance. Your task is to detect if at any point the balance of account fallls below zero, and\n at that point function should return True. Otherwise it should return False.\n >>> below_zero([1, 2, 3])\n False\n >>> below_zero([1, 2, -4, 5])\n True\n \"\"\"\n", "entry_point": "below_zero", "canonical_solution": " balance = 0\n\n for op in operations:\n balance += op\n if balance < 0:\n return True\n\n return False\n", "test": "\n\nMETADATA = {\n 'author': 'jt',\n 'dataset': 'test'\n}\n\n\ndef check(candidate):\n assert candidate([]) == False\n assert candidate([1, 2, -3, 1, 2, -3]) == False\n assert candidate([1, 2, -4, 5, 6]) == True\n assert candidate([1, -1, 2, -2, 5, -5, 4, -4]) == False\n assert candidate([1, -1, 2, -2, 5, -5, 4, -5]) == True\n assert candidate([1, -2, 2, -2, 5, -5, 4, -4]) == True\n", "test_inputs": ["([],)", "([1, 2, -3, 1, 2, -3],)", "([1, 2, -4, 5, 6],)", "([1, -1, 2, -2, 5, -5, 4, -4],)", "([1, -1, 2, -2, 5, -5, 4, -5],)", "([1, -2, 2, -2, 5, -5, 4, -4],)"], "test_outputs": ["False", "False", "True", "False", "True", "True"], "language": "python"} diff --git a/tests/eval/local_data/lambada_small.jsonl b/tests/eval/local_data/lambada_small.jsonl new file mode 100644 index 0000000000..5a0dc238ae --- /dev/null +++ b/tests/eval/local_data/lambada_small.jsonl @@ -0,0 +1,4 @@ +{"context": "With Tristran's next step he was standing beside a lake, and the candlelight shone brightly on the water; and then he was walking through the mountains, through lonely crags, where the candlelight was reflected in the eyes of the creatures of the high snows; and then he was walking through the clouds, which, while not entirely substantial, still supported his weight in comfort; and then, holding tightly to his candle, he was underground, and the candlelight glinted back at him from the wet cave walls; now he was in the mountains once more; and then he was on a road through wild forest, and he glimpsed a chariot being pulled by two goats, being driven by a woman in a red dress who looked, for the glimpse he got of her, the way Boadicea was drawn in his history books; and another step and he was in a leafy glen, and he could hear the chuckle of water as it splashed and sang its way into a small brook.\n\nHe took another step, but he was still in the", "continuation": "glen"} +{"context": "Todd replied: No I thought you looked familiar but I can’t recall. The stranger told Todd: I’m Enoch; we met in your dream. Todd looked back again, this time he realized it really was Enoch; Todd stopped on the side of the road, leaned back and tried to see if he was dreaming. When Enoch said: No Todd you’re not", "continuation": "dreaming"} +{"context": "The Librarian thumbed through the bundle of pages, stopping on the final sheet and began reading, “It is our conclusion that much of the work that is currently done in the Library can be out-sourced to contractors, particularly non-skill specific work such as shelving, stacking...”\nLucy gulped and Gillian began to open her mouth to protest again, but the Librarian carried on regardless, his voice becoming louder in order to drown out any potentially dissenting voices, “... blah, blah, blah. It is our recommendation that a downsizing of the non-essential and part-time members of staff would bring instant economy of scale benefits and would allow for the implementation of a new middle management structure.”\n“You mean sacrifice the troops to pay for the generals,” said", "continuation": "Gillian"} +{"context": "He was small, even for a dwarf, and his poor taste in sorcerous robes contrasted awkwardly with D’jebee’s elegant attire; her long, diaphanous gown and his chemical-stained, star-spangled robe clashed almost as much as her vacuous expression alongside his own visage, alive as it was with cunning and a twisted intelligence.\n\nD’jebee sighed with boredom.\n\n‘What is it, my love?’ Poldanyelz oozed with ersatz concern.\n\n‘I’m bored,’ D’jebee complained undiplomatically. ‘No one ever comes here. I never see anyone except you.’\n\nA shuffling from the main arch alerted her to the inaccuracy of her", "continuation": "statement"} diff --git a/tests/eval/local_data/mmlu_small.jsonl b/tests/eval/local_data/mmlu_small.jsonl new file mode 100644 index 0000000000..90eb402607 --- /dev/null +++ b/tests/eval/local_data/mmlu_small.jsonl @@ -0,0 +1,4 @@ +{"query": "Question: How is IP address spoofing detected?\n(A) Installing and configuring a IDS that can read the IP header (B) Comparing the TTL values of the actual and spoofed addresses (C) Implementing a firewall to the network (D) Identify all TCP sessions that are initiated but does not complete successfully\n", "gold": 1, "choices": ["A", "B", "C", "D"], "category": "computer_security"} +{"query": "Question: Which of the following is not an example of presentation layer issues?\n(A) Poor handling of unexpected input can lead to the execution of arbitrary instructions (B) Unintentional or ill-directed use of superficially supplied input (C) Cryptographic flaws in the system may get exploited to evade privacy (D) Weak or non-existent authentication mechanisms\n", "gold": 3, "choices": ["A", "B", "C", "D"], "category": "computer_security"} +{"query": "Question: Suppose Unix did not provide a way of passing file descriptors between processes, but still allowed inheriting file descriptors from a parent on fork and exec. What aspects of the OKWS design would break without file descriptor passing?\n1. It would be impossible for services to send messages to oklogd.\n2. It would be impossible for services to get a TCP connection to a database proxy.\n(A) True, True (B) False, False (C) True, False (D) False, True\n", "gold": 1, "choices": ["A", "B", "C", "D"], "category": "computer_security"} +{"query": "Question: Why would a ping sweep be used?\n(A) To identify live systems (B) To locate live systems (C) To identify open ports (D) To locate firewalls\n", "gold": 0, "choices": ["A", "B", "C", "D"], "category": "computer_security"} diff --git a/tests/eval/local_data/piqa_small.jsonl b/tests/eval/local_data/piqa_small.jsonl new file mode 100644 index 0000000000..07b1b27509 --- /dev/null +++ b/tests/eval/local_data/piqa_small.jsonl @@ -0,0 +1,4 @@ +{"choices": ["Pour it onto a plate", "Pour it into a jar"], "gold": 1, "query": "When boiling butter, when it's ready, you can"} +{"choices": ["Weld the metal together to get it to stay firmly in place", "Nail the metal together to get it to stay firmly in place"], "gold": 0, "query": "To permanently attach metal legs to a chair, you can"} +{"choices": ["leave a space before starting the writing", "press the spacebar"], "gold": 0, "query": "how do you indent something?"} +{"choices": ["move it up and down and side to side quickly.", "stir it very quickly."], "gold": 0, "query": "how do you shake something?"} diff --git a/tests/eval/local_data/pubmed_sm.jsonl b/tests/eval/local_data/pubmed_sm.jsonl new file mode 100644 index 0000000000..c39bab0b04 --- /dev/null +++ b/tests/eval/local_data/pubmed_sm.jsonl @@ -0,0 +1,4 @@ +{"context": "Context: PURPOSE. To assess whether eligibility to an adjuvant chemotherapy protocol in itself represents a good prognostic factor after radical cystectomy for bladder cancer.\nPATIENTS AND METHODS. Between April 1984 and May 1989, our institution entered 35 patients with invasive bladder cancer into the Swiss Group for Clinical and Epidemiological Cancer Research (SAKK) study 09/84. They were randomly assigned to either observation or three postoperative courses of cisplatin monotherapy after cystectomy. This study had a negative result. The outcome of these 35 patients (protocol group) was compared with an age- and tumor-stage-matched cohort (matched group; n = 35) who also underwent cystectomy during the same period, but were not entered into the SAKK study, as well as the remaining 57 patients treated during the study period for the same indication (remaining group).\nRESULTS. Median overall survival decreased from 76.3 months in the protocol group to 52.1 months in the matched group and to 20.3 months in the remaining group. The respective times of median recurrence-free survival were 67.2, 16.0, and 9.4 months. Tumor progression occurred in 46% of the protocol group compared with 69% in the matched group and 65% in the remaining group (P<.05). Cancer-related death was noted in 40% of the protocol group, 57% in the matched group, and 56% in the remaining group.\nQuestion: Is eligibility for a chemotherapy protocol a good prognostic factor for invasive bladder cancer after radical cystectomy?\nA. yes\nB. no\nC. maybe\nAnswer: ", "continuation": "yes"} +{"context": "Context: BACKGROUND. This study was performed to describe the treatment plan modifications after a geriatric oncology clinic. Assessment of health and functional status and cancer assessment was performed in older cancer patients referred to a cancer center.\nPATIENTS AND METHODS. Between June 2004 and May 2005, 105 patients 70 years old or older referred to a geriatric oncology consultation at the Institut Curie cancer center were included. Functional status, nutritional status, mood, mobility, comorbidity, medication, social support, and place of residence were assessed. Oncology data and treatment decisions were recorded before and after this consultation. Data were analyzed for a possible correlation between one domain of the assessment and modification of the treatment plan.\nRESULTS. Patient characteristics included a median age of 79 years and a predominance of women with breast cancer. About one half of patients had an independent functional status. Nearly 15% presented severe undernourishment. Depression was suspected in 53.1% of cases. One third of these patients had>2 chronic diseases, and 74% of patients took>or =3 medications. Of the 93 patients with an initial treatment decision, the treatment plan was modified for 38.7% of cases after this assessment. Only body mass index and the absence of depressive symptoms were associated with a modification of the treatment plan.\nQuestion: Does a geriatric oncology consultation modify the cancer treatment plan for elderly patients?\nA. yes\nB. no\nC. maybe\nAnswer: ", "continuation": "yes"} +{"context": "Context: BACKGROUND. The alterations of echocardiography and electrocardiogram (ECG) in patients received left atrial appendage LAA occlusion therapy are still unclear. The present study was to evaluate the influence of LAA occlusion device on echocardiography and ECG changes in patients with atrial fibrillation (AF).\nMETHODS. Seventy-three patients who had undergone Watchman, LAmbre and Lefort were enrolled in this study. Echocardiography and ECG results at pre- and post-operation were collected. Besides, echocardiography was also performed during follow-up visits at 1, 6 and 12months after discharge.\nRESULTS. After LAA occlusion, a slight and measureable movement of QRS electric axis was observed in most patients. The significant differences were also observed in heart rate (HR) and the mean-mean QT interval between pre- and post-operation for all patients. There existed no significant difference in echocardiographic parameters between before and after device implantation. However, a larger left atrial (LA) diameter was detected by echocardiography during follow-up visit at 6months when compared with pre-operation parameters. Similarly, aortic root diameter (ARD) was also larger during follow-up at 12months than the baseline dimension in pre-operation.\nQuestion: Does left atrial appendage (LAA) occlusion device alter the echocardiography and electrocardiogram parameters in patients with atrial fibrillation?\nA. yes\nB. no\nC. maybe\nAnswer: ", "continuation": "yes"} +{"context": "Context: BACKGROUND. Currently the choice of breast cancer therapy is based on prognostic factors. The proliferation marker Ki-67 is used increasingly to determine the method of therapy. The current study analyses the predictive value of Ki-67 in foreseeing breast cancer patients' responses to neoadjuvant chemotherapy.\nMETHODS. This study includes patients with invasive breast cancer treated between 2008 and 2013. The clinical response was assessed by correlating Ki-67 to histological examination, mammography, and ultrasonography findings.\nRESULTS. The average Ki-67 value in our patients collectively (n = 77) is 34.9 ± 24.6%. The average Ki-67 value is the highest with 37.4 ± 24.0% in patients with a pCR. The Ki-67 values do not differ significantly among the 3 groups: pCR versus partial pathological response versus stable disease/progress (P = 0.896). However, Ki-67 values of patients with luminal, Her2 enriched, and basal-like cancers differed significantly from each other. Furthermore, within the group of luminal tumors Ki-67 values of patients with versus without pCR also differed significantly.\nQuestion: Can ki-67 play a role in prediction of breast cancer patients' response to neoadjuvant chemotherapy?\nA. yes\nB. no\nC. maybe\nAnswer: ", "continuation": "yes"} diff --git a/tests/eval/local_data/triviaqa_small.jsonl b/tests/eval/local_data/triviaqa_small.jsonl new file mode 100644 index 0000000000..ae5e0783d9 --- /dev/null +++ b/tests/eval/local_data/triviaqa_small.jsonl @@ -0,0 +1,4 @@ +{"context": "Who was the man behind The Chipmunks?", "answer": "David Seville", "aliases": ["David Seville"]} +{"context": "What star sign is Jamie Lee Curtis?", "answer": "Scorpio", "aliases": ["Scorpio", "Skorpio"]} +{"context": "Which Lloyd Webber musical premiered in the US on 10th December 1993?", "answer": "Sunset Boulevard", "aliases": ["Sunset Blvd", "Sunset Boulevard", "Sunset Bulevard", "West Sunset Boulevard"]} +{"context": "Who was the next British Prime Minister after Arthur Balfour?", "answer": "Campbell-Bannerman", "aliases": ["Campbell Bannerman", "Campbell-Bannerman", "Henry Campbell Bannerman", "Henry Campbell-Bannerman", "Sir Henry Campbell Bannerman", "Sir Henry Campbell-Bannerman"]} diff --git a/tests/eval/local_data/winograd_small.jsonl b/tests/eval/local_data/winograd_small.jsonl new file mode 100644 index 0000000000..8f84cd27e5 --- /dev/null +++ b/tests/eval/local_data/winograd_small.jsonl @@ -0,0 +1,4 @@ +{"context_options": ["The city councilmen refused the demonstrators a permit because the city councilmen", "The city councilmen refused the demonstrators a permit because the demonstrators"], "continuation": "feared violence.", "gold": 0} +{"context_options": ["The city councilmen refused the demonstrators a permit because the city councilmen", "The city councilmen refused the demonstrators a permit because the demonstrators"], "continuation": "advocated violence.", "gold": 1} +{"context_options": ["The trophy doesn't fit into the brown suitcase because the trophy", "The trophy doesn't fit into the brown suitcase because the suitcase"], "continuation": "is too large.", "gold": 0} +{"context_options": ["The trophy doesn't fit into the brown suitcase because the trophy", "The trophy doesn't fit into the brown suitcase because the suitcase"], "continuation": "is too small.", "gold": 1} diff --git a/tests/eval/test_in_context_learning_datasets.py b/tests/eval/test_in_context_learning_datasets.py new file mode 100644 index 0000000000..33a041aaea --- /dev/null +++ b/tests/eval/test_in_context_learning_datasets.py @@ -0,0 +1,2841 @@ +# Copyright 2024 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +import contextlib +import os +import random +import types +from pathlib import Path +from typing import Dict, List, Optional + +import pytest +import torch +from composer import Evaluator +from composer.core import DataSpec +from torch.utils.data import DataLoader + +# isort: off +from llmfoundry.eval.datasets import ( + InContextLearningDataset, InContextLearningCodeEvalDataset, + InContextLearningMultipleChoiceTaskDataset, + InContextLearningGenerationTaskWithAnswersDataset, + InContextLearningSchemaTaskDataset, get_icl_task_dataloader, strip_data, + tokenizer_needs_prefix_space, trim_context, get_continuation_span, + get_fewshot_sample_idxs, make_padded_input) +# isort: on +import transformers +from composer.datasets.utils import MultiTokenEOSCriteria +from composer.loggers import InMemoryLogger +from composer.models import HuggingFaceModel +from composer.trainer import Trainer +from composer.utils import dist, reproducibility + +from llmfoundry.eval.metrics import ( + InContextLearningCodeEvalAccuracy, + InContextLearningGenerationExactMatchAccuracy, InContextLearningLMAccuracy, + InContextLearningMultipleChoiceAccuracy) + + +def test_strip_data(): + data_to_strip = { + 'strip_data': ' boo! \n', + 'has_space': ' wa hoo!', + 'end_space': 'yoohoo! ' + } + stripped_data = strip_data(data_to_strip) + for k, v in stripped_data.items(): + assert k in data_to_strip + assert not v[0].isspace() + assert not v[-1].isspace() + + +@pytest.mark.skip( + reason="Currently don't have a tokenizer that satisfies this test") +def test_tokenizer_needs_prefix_space_when_space_not_needed( + tiny_gpt2_tokenizer: transformers.AutoTokenizer): + assert not tokenizer_needs_prefix_space(tiny_gpt2_tokenizer) + + +def test_tokenizer_needs_prefix_space_when_space_needed(): + tokenizer = transformers.AutoTokenizer.from_pretrained( + 'facebook/opt-125m', + use_fast=False) # type: ignore reportUnboundVariable + assert tokenizer_needs_prefix_space(tokenizer) + + +def test_trim_context(): + context = [0] * 99 + [1] * 2037 + continuation = [2] * 10 + max_seq_len = 2048 + trimmed_context = trim_context(context, + continuation, + max_seq_len=max_seq_len) + assert len(trimmed_context) == 2038 + assert trimmed_context[0] == 0 + assert trimmed_context[1] == 1 + + +def test_trim_context_no_continuation(): + context = [0] * 2048 + max_seq_len = 2048 + trimmed_context = trim_context(context, [], max_seq_len=max_seq_len) + assert len(trimmed_context) == 2048 + context = [0] * 3000 + [1] + max_seq_len = 2048 + trimmed_context = trim_context(context, [], max_seq_len=max_seq_len) + assert len(trimmed_context) == 2048 + assert trimmed_context[-1] == 1 + + +def test_get_continuation_span(): + context = [0] * 200 + continuation = [1] * 3 + cont_span = get_continuation_span(context, continuation) + assert torch.all(torch.eq(cont_span, torch.tensor([200, 201, 202]))) + continuation = [1] + cont_span = get_continuation_span(context, continuation) + assert torch.all(torch.eq(cont_span, torch.tensor([200]))) + + +@pytest.mark.parametrize('padding_side', ['left', 'right', 'middle']) +def test_make_padding(tiny_gpt2_tokenizer: transformers.AutoTokenizer, + padding_side: str): + context = tiny_gpt2_tokenizer(' cat' * 2000)['input_ids'] + padding_id = tiny_gpt2_tokenizer.eos_token_id + + error_context = contextlib.nullcontext() if padding_side in { + 'left', 'right' + } else pytest.raises(ValueError) + + with error_context: + input_ids = make_padded_input(context, [], + 2048, + padding_id, + padding_side=padding_side) + + if padding_side == 'left': + assert input_ids[0] == tiny_gpt2_tokenizer.eos_token_id + assert input_ids[48:].tolist() == context + elif padding_side == 'right': + assert input_ids[-1] == tiny_gpt2_tokenizer.eos_token_id + assert input_ids[:-48].tolist() == context + + +def test_batch_padding_logic_no_padding( + tiny_gpt2_tokenizer: transformers.AutoTokenizer): + continuation = tiny_gpt2_tokenizer(' dog' * 2000)['input_ids'] + context = tiny_gpt2_tokenizer(' cat' * 2000)['input_ids'] + max_seq_len = 2048 + trimmed_context = trim_context(context, continuation, max_seq_len) + continuation_spans = get_continuation_span(trimmed_context, continuation) + padded_input = make_padded_input(trimmed_context, + continuation, + max_seq_len, + tiny_gpt2_tokenizer.pad_token_id, + padding_side='right') + assert continuation_spans[0] == 48 and continuation_spans[-1] == 2047 + assert len(padded_input) == 2048 + assert tiny_gpt2_tokenizer.pad_token_id not in padded_input + + +def test_batch_padding_logic_with_padding( + tiny_gpt2_tokenizer: transformers.AutoTokenizer): + continuation = tiny_gpt2_tokenizer(' dog' * 200)['input_ids'] + context = tiny_gpt2_tokenizer(' cat' * 200)['input_ids'] + max_seq_len = 2048 + trimmed_context = trim_context(context, continuation, max_seq_len) + continuation_spans = get_continuation_span(trimmed_context, continuation) + padded_input = make_padded_input(trimmed_context, + continuation, + max_seq_len, + tiny_gpt2_tokenizer.pad_token_id, + padding_side='right') + assert continuation_spans[0] == 200 and continuation_spans[-1] == 399 + assert len(padded_input) == 2048 + assert padded_input[-1] == tiny_gpt2_tokenizer.pad_token_id + + +def test_fewshot_sample_idxs(): + rng = random.Random(1234) + + fewshot_idxs = get_fewshot_sample_idxs(dataset_size=5, + num_fewshot=4, + example_idx=4, + rng=rng) + assert fewshot_idxs == {0, 1, 2, 3} + + fewshot_idxs = get_fewshot_sample_idxs(dataset_size=5, + num_fewshot=5, + example_idx=4, + rng=rng) + assert fewshot_idxs == {0, 1, 2, 3} + + fewshot_idxs = get_fewshot_sample_idxs(dataset_size=5, + num_fewshot=500, + example_idx=4, + rng=rng) + assert fewshot_idxs == {0, 1, 2, 3} + + fewshot_idxs = get_fewshot_sample_idxs(dataset_size=10, + num_fewshot=7, + example_idx=4, + rng=rng) + assert len(fewshot_idxs) == 7 and 4 not in fewshot_idxs + + +def test_fewshot_sample_idxs_randomness(): + dataset_size = 10000 + num_fewshot = 5 + + rng_1_seed_1234 = random.Random(1234) + rng_2_seed_1234 = random.Random(1234) + rng_3_seed_11 = random.Random(11) + + rng_1_sample_1 = get_fewshot_sample_idxs(dataset_size, num_fewshot, 1, + rng_1_seed_1234) + rng_2_sample_1 = get_fewshot_sample_idxs(dataset_size, num_fewshot, 1, + rng_2_seed_1234) + rng_3_sample_1 = get_fewshot_sample_idxs(dataset_size, num_fewshot, 1, + rng_3_seed_11) + + assert rng_1_sample_1 == rng_2_sample_1 + assert rng_1_sample_1 != rng_3_sample_1 + + rng_1_sample_2 = get_fewshot_sample_idxs(dataset_size, num_fewshot, 2, + rng_1_seed_1234) + rng_2_sample_2 = get_fewshot_sample_idxs(dataset_size, num_fewshot, 2, + rng_2_seed_1234) + rng_3_sample_2 = get_fewshot_sample_idxs(dataset_size, num_fewshot, 2, + rng_3_seed_11) + + assert rng_1_sample_2 == rng_2_sample_2 + assert rng_1_sample_2 != rng_3_sample_2 + + +@pytest.mark.filterwarnings( + r'ignore:The repository for mosaicml/test_dataset contains custom code which must*:FutureWarning' +) +def test_update_generation_kwargs( + tiny_gpt2_tokenizer: transformers.AutoTokenizer, tmp_path: Path): + tokenizer = tiny_gpt2_tokenizer + seqlen = 2048 + num_fewshot = 0 + prompt_string = '' + hf_loading_vars = { + 'split': 'test', + 'name': 'invoker', + } + hf_parsing_map = {'context': ['quas', 'wex', 'exort'], 'answer': ['spell']} + gen_kwargs = {'test_arg1': 1, 'test_arg2': 2} + + dl = InContextLearningDataset( + dataset_uri='hf://mosaicml/test_dataset', + tokenizer=tokenizer, + max_seq_len=seqlen, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=num_fewshot, + fewshot_random_seed=1, + prompt_string=prompt_string, + example_delimiter='\n', + prelimiter='Orbs: ', + continuation_delimiter='\nSpell:', + destination_path=str(tmp_path / 'test_dataset_lm_juggernaut.jsonl'), + hf_loading_vars=hf_loading_vars, + hf_parsing_map=hf_parsing_map, + generation_kwargs=gen_kwargs) + assert dl.base_batch['generation_kwargs'] == { + 'test_arg1': 1, + 'test_arg2': 2 + } + + +def test_stop_sequences_criteria( + tiny_gpt2_tokenizer: transformers.AutoTokenizer): + eos_criteria = MultiTokenEOSCriteria('\n\n', tiny_gpt2_tokenizer, 2) + seq1 = tiny_gpt2_tokenizer('Dogs are furry')['input_ids'] + seq2 = tiny_gpt2_tokenizer('Dogs are furry\n\n')['input_ids'] + seq1 = [tiny_gpt2_tokenizer.pad_token_id] * (len(seq2) - len(seq1)) + seq1 + input_ids = torch.LongTensor([seq1, seq2]) + assert not eos_criteria(input_ids, + None) # pyright: ignore[reportGeneralTypeIssues] + + eos_criteria = MultiTokenEOSCriteria('\n\n', tiny_gpt2_tokenizer, 2) + seq1 = tiny_gpt2_tokenizer('Dogs are furry\n\n')['input_ids'] + seq2 = tiny_gpt2_tokenizer('Dogs are furry\n\n')['input_ids'] + input_ids = torch.LongTensor([seq1, seq2]) + assert eos_criteria(input_ids, + None) # pyright: ignore[reportGeneralTypeIssues] + + +def test_stop_sequences_criteria_sentencepiece( + tiny_llama_tokenizer: transformers.AutoTokenizer): + + tokenizer = tiny_llama_tokenizer + eos_criteria = MultiTokenEOSCriteria('\n\n', tokenizer, 2) + seq1 = tokenizer( + '\n\nDogs' + )['input_ids'] # check to make sure starting with the stop sequence doesnt break it + seq2 = tokenizer('Dogs are furry\n\n')['input_ids'] + seq1 = [tokenizer.eos_token_id] * (len(seq2) - len(seq1)) + seq1 + input_ids = torch.LongTensor([seq1, seq2]) + assert not eos_criteria(input_ids, + None) # pyright: ignore[reportGeneralTypeIssues] + + eos_criteria = MultiTokenEOSCriteria('\n\n', tokenizer, 2) + seq1 = tokenizer('Dogs are furry\n\n')['input_ids'] + seq2 = tokenizer('Dogs are furry\n\n')['input_ids'] + input_ids = torch.LongTensor([seq1, seq2]) + assert eos_criteria(input_ids, + None) # pyright: ignore[reportGeneralTypeIssues] + + +@pytest.mark.filterwarnings( + r'ignore:The repository for mosaicml/test_dataset contains custom code which must*:FutureWarning' +) +def test_update_generation_kwargs_no_kwargs( + tiny_gpt2_tokenizer: transformers.AutoTokenizer, tmp_path: Path): + tokenizer = tiny_gpt2_tokenizer + seqlen = 2048 + num_fewshot = 0 + prompt_string = '' + hf_loading_vars = { + 'split': 'test', + 'name': 'invoker', + } + hf_parsing_map = {'context': ['quas', 'wex', 'exort'], 'answer': ['spell']} + + dl = InContextLearningDataset( + dataset_uri='hf://mosaicml/test_dataset', + tokenizer=tokenizer, + max_seq_len=seqlen, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=num_fewshot, + fewshot_random_seed=1, + prompt_string=prompt_string, + example_delimiter='\n', + prelimiter='Orbs: ', + continuation_delimiter='\nSpell:', + destination_path=str(tmp_path / 'test_dataset_lm_juggernaut.jsonl'), + hf_loading_vars=hf_loading_vars, + hf_parsing_map=hf_parsing_map) + assert not 'generation_kwargs' in dl.base_batch + + +def test_update_generation_kwargs_no_kwargs_qa_dataset(tmp_path: Path): + local_data = os.path.join(os.path.dirname(__file__), 'local_data') + dataset_uri = f'{local_data}/triviaqa_small.jsonl' + + tokenizer = transformers.AutoTokenizer.from_pretrained( + 'facebook/opt-125m') # type: ignore reportUnboundVariable + + tmp_path_to_broadcast = str(os.path.abspath(tmp_path)) + gathered_paths = dist.all_gather_object(tmp_path_to_broadcast) + dl = InContextLearningGenerationTaskWithAnswersDataset( + dataset_uri=dataset_uri, + tokenizer=tokenizer, + max_seq_len=1024, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=0, + fewshot_random_seed=1234, + prompt_string='', + example_delimiter='\n', + continuation_delimiter=': ', + destination_path=str(Path(gathered_paths[0]) / 'icl.jsonl'), + generation_kwargs=None) + assert len(dl.base_batch['generation_kwargs']) == 4 + + +def test_update_generation_kwargs_with_kwargs_qa_dataset(tmp_path: Path): + local_data = os.path.join(os.path.dirname(__file__), 'local_data') + dataset_uri = f'{local_data}/triviaqa_small.jsonl' + + tokenizer = transformers.AutoTokenizer.from_pretrained( + 'facebook/opt-125m') # type: ignore reportUnboundVariable + + tmp_path_to_broadcast = str(os.path.abspath(tmp_path)) + gathered_paths = dist.all_gather_object(tmp_path_to_broadcast) + dl = InContextLearningGenerationTaskWithAnswersDataset( + dataset_uri=dataset_uri, + tokenizer=tokenizer, + max_seq_len=1024, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=0, + fewshot_random_seed=1234, + prompt_string='', + example_delimiter='\n', + continuation_delimiter=': ', + destination_path=str(Path(gathered_paths[0]) / 'icl.jsonl'), + generation_kwargs={'temperature': 0.9}) + assert 'generation_kwargs' in dl.base_batch + assert dl.base_batch['generation_kwargs']['temperature'] == 0.9 + assert len(dl.base_batch['generation_kwargs']) == 5 + + +@pytest.mark.filterwarnings( + r'ignore:The repository for mosaicml/test_dataset contains custom code which must*:FutureWarning' +) +def test_construct_context(tiny_gpt2_tokenizer: transformers.AutoTokenizer, + tmp_path: Path): + tokenizer = tiny_gpt2_tokenizer + seqlen = 2048 + num_fewshot = 0 + prompt_string = '' + hf_loading_vars = { + 'split': 'test', + 'name': 'invoker', + } + hf_parsing_map = {'context': ['quas', 'wex', 'exort'], 'answer': ['spell']} + + dl = InContextLearningDataset( + dataset_uri='hf://mosaicml/test_dataset', + tokenizer=tokenizer, + max_seq_len=seqlen, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=num_fewshot, + fewshot_random_seed=1, + prompt_string=prompt_string, + example_delimiter='\n', + prelimiter='Orbs: ', + continuation_delimiter='\nSpell: ', + destination_path=str(tmp_path / 'test_dataset_lm_juggernaut.jsonl'), + hf_loading_vars=hf_loading_vars, + hf_parsing_map=hf_parsing_map) + constructed_context = dl.construct_context({ + 'context': 'quas quas exort', + 'answer': 'ice wall' + }) + assert constructed_context == 'Orbs: quas quas exort\nSpell: ' + constructed_context = dl.construct_context( + { + 'context': 'quas quas exort', + 'answer': 'ice wall' + }, add_answer=True) + assert constructed_context == 'Orbs: quas quas exort\nSpell: ice wall' + constructed_context = dl.construct_context( + { + 'context': 'quas quas exort', + 'answer': 'ice wall' + }, + preceding_text='The harsh White Waste beckons!', + add_answer=True) + assert constructed_context == '\nOrbs: quas quas exort\nSpell: ice wall' + + +@pytest.mark.filterwarnings( + r'ignore:The repository for mosaicml/test_dataset contains custom code which must*:FutureWarning' +) +def test_get_answer_from_example( + tiny_gpt2_tokenizer: transformers.AutoTokenizer, tmp_path: Path): + tokenizer = tiny_gpt2_tokenizer + seqlen = 2048 + num_fewshot = 0 + prompt_string = '' + hf_loading_vars = { + 'split': 'test', + 'name': 'invoker', + } + hf_parsing_map = {'context': ['quas', 'wex', 'exort'], 'answer': ['spell']} + + dl = InContextLearningDataset( + dataset_uri='hf://mosaicml/test_dataset', + tokenizer=tokenizer, + max_seq_len=seqlen, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=num_fewshot, + fewshot_random_seed=1, + prompt_string=prompt_string, + example_delimiter='\n', + prelimiter='Orbs: ', + continuation_delimiter='\nSpell:', + destination_path=str(tmp_path / 'test_dataset_lm_juggernaut.jsonl'), + hf_loading_vars=hf_loading_vars, + hf_parsing_map=hf_parsing_map) + answer = dl.get_answer_from_example({ + 'context': 'wex exort exort', + 'answer': 'alacrity' + }) + assert answer == ' alacrity' + + +@pytest.mark.filterwarnings( + r'ignore:The repository for mosaicml/test_dataset contains custom code which must*:FutureWarning' +) +def test_fix_eos_on_preamble(tmp_path: Path): + tokenizer = transformers.AutoTokenizer.from_pretrained( + 'facebook/opt-125m', + use_fast=False) # type: ignore reportUnboundVariable + seqlen = 2048 + num_fewshot = 0 + prompt_string = '' + hf_loading_vars = { + 'split': 'test', + 'name': 'invoker', + } + hf_parsing_map = {'context': ['quas', 'wex', 'exort'], 'answer': ['spell']} + + dl = InContextLearningDataset( + dataset_uri='hf://mosaicml/test_dataset', + tokenizer=tokenizer, + max_seq_len=seqlen, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=num_fewshot, + fewshot_random_seed=1, + prompt_string=prompt_string, + example_delimiter='\n', + prelimiter='Orbs: ', + continuation_delimiter='\nSpell:', + destination_path=str(tmp_path / 'test_dataset_lm_juggernaut.jsonl'), + hf_loading_vars=hf_loading_vars, + hf_parsing_map=hf_parsing_map) + preamble = 'blah blah blah.' + tokenized_preamble = tokenizer.encode(preamble) + tokenized_preamble += [tokenizer.eos_token_id] + fixed_preamble = dl._fix_eos_on_preamble(tokenized_preamble) + assert tokenized_preamble[:-1] == fixed_preamble + assert fixed_preamble[-1] != tokenizer.eos_token_id + + +@pytest.mark.filterwarnings( + r'ignore:The repository for mosaicml/test_dataset contains custom code which must*:FutureWarning' +) +def test_tokenize_example_with_tokenize_labels( + tiny_gpt2_tokenizer: transformers.AutoTokenizer, tmp_path: Path): + tokenizer = tiny_gpt2_tokenizer + seqlen = 2048 + num_fewshot = 0 + prompt_string = '' + hf_loading_vars = { + 'split': 'test', + 'name': 'invoker', + } + hf_parsing_map = {'context': ['quas', 'wex', 'exort'], 'answer': ['spell']} + + dl = InContextLearningDataset( + dataset_uri='hf://mosaicml/test_dataset', + tokenizer=tokenizer, + max_seq_len=seqlen, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=num_fewshot, + fewshot_random_seed=1, + prompt_string=prompt_string, + example_delimiter='\n', + prelimiter='Orbs: ', + continuation_delimiter='\nSpell: ', + destination_path=str(tmp_path / 'test_dataset_lm_juggernaut.jsonl'), + hf_loading_vars=hf_loading_vars, + hf_parsing_map=hf_parsing_map, + tokenize_labels=True) + tokenized_example = dl.tokenize_example('What spell does this invoke? ', + 'exort exort wex\nSpell: ', + {'answer': ' Meatball'}) + tokenized_input = [ + 2061, 4822, 857, 428, 26342, 30, 220, 1069, 419, 409, 419, 356, 87, 198, + 31221, 25, 19145, 1894 + ] + assert tokenized_example['context'][:len(tokenized_input)].tolist( + ) == tokenized_input + assert tokenized_example['context'][-1] == tokenizer.eos_token_id + assert type(tokenized_example['answer'][0]) == int + assert len(tokenized_example['context']) == seqlen + assert 'continuation_indices' in tokenized_example + + +@pytest.mark.filterwarnings( + r'ignore:The repository for mosaicml/test_dataset contains custom code which must*:FutureWarning' +) +def test_tokenize_example_with_no_tokenize_labels( + tiny_gpt2_tokenizer: transformers.AutoTokenizer, tmp_path: Path): + tokenizer = tiny_gpt2_tokenizer + seqlen = 2048 + num_fewshot = 0 + prompt_string = '' + hf_loading_vars = { + 'split': 'test', + 'name': 'invoker', + } + hf_parsing_map = {'context': ['quas', 'wex', 'exort'], 'answer': ['spell']} + + dl = InContextLearningDataset( + dataset_uri='hf://mosaicml/test_dataset', + tokenizer=tokenizer, + max_seq_len=seqlen, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=num_fewshot, + fewshot_random_seed=1, + prompt_string=prompt_string, + example_delimiter='\n', + prelimiter='Orbs: ', + continuation_delimiter='\nSpell: ', + destination_path=str(tmp_path / 'test_dataset_lm_juggernaut.jsonl'), + hf_loading_vars=hf_loading_vars, + hf_parsing_map=hf_parsing_map, + tokenize_labels=False) + tokenized_example = dl.tokenize_example('What spell does this invoke? ', + 'exort exort wex\nSpell: ', + {'answer': ' Meatball'}) + tokenized_input = [ + 2061, 4822, 857, 428, 26342, 30, 220, 1069, 419, 409, 419, 356, 87, 198, + 31221, 25 + ] + assert tokenized_example['context'][:len(tokenized_input)].tolist( + ) == tokenized_input + assert tokenized_example['context'][-1] == tokenizer.eos_token_id + assert len(tokenized_example['context']) == seqlen + assert type(tokenized_example['answer']) == str + + +def test_qa_set_cot_no_cot(tmp_path: Path): + local_data = os.path.join(os.path.dirname(__file__), 'local_data') + dataset_uri = f'{local_data}/triviaqa_small.jsonl' + + tokenizer = transformers.AutoTokenizer.from_pretrained( + 'facebook/opt-125m') # type: ignore reportUnboundVariable + + tmp_path_to_broadcast = str(os.path.abspath(tmp_path)) + gathered_paths = dist.all_gather_object(tmp_path_to_broadcast) + dl = InContextLearningGenerationTaskWithAnswersDataset( + dataset_uri=dataset_uri, + tokenizer=tokenizer, + max_seq_len=1024, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=0, + fewshot_random_seed=1234, + prompt_string='', + example_delimiter='\n', + continuation_delimiter=': ', + destination_path=str(Path(gathered_paths[0]) / 'icl.jsonl'), + ) + assert not dl.has_cot + + +def test_qa_set_cot_has_cot(tmp_path: Path): + local_data = os.path.join(os.path.dirname(__file__), 'local_data') + dataset_uri = f'{local_data}/gsm8k_small.jsonl' + + tokenizer = transformers.AutoTokenizer.from_pretrained( + 'facebook/opt-125m') # type: ignore reportUnboundVariable + + tmp_path_to_broadcast = str(os.path.abspath(tmp_path)) + gathered_paths = dist.all_gather_object(tmp_path_to_broadcast) + dl = InContextLearningGenerationTaskWithAnswersDataset( + dataset_uri=dataset_uri, + tokenizer=tokenizer, + max_seq_len=1024, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=0, + fewshot_random_seed=1234, + prompt_string='', + example_delimiter='\n', + continuation_delimiter=': ', + destination_path=str(Path(gathered_paths[0]) / 'icl.jsonl'), + ) + assert dl.has_cot + + +def test_qa_get_max_answer_length( + tiny_gpt2_tokenizer: transformers.AutoTokenizer, tmp_path: Path): + local_data = os.path.join(os.path.dirname(__file__), 'local_data') + dataset_uri = f'{local_data}/triviaqa_small.jsonl' + tokenizer = tiny_gpt2_tokenizer + + tmp_path_to_broadcast = str(os.path.abspath(tmp_path)) + gathered_paths = dist.all_gather_object(tmp_path_to_broadcast) + dl = InContextLearningGenerationTaskWithAnswersDataset( + dataset_uri=dataset_uri, + tokenizer=tokenizer, + max_seq_len=1024, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=0, + fewshot_random_seed=1234, + prompt_string='', + example_delimiter='', + continuation_delimiter='', + cot_delimiter='', + destination_path=str(Path(gathered_paths[0]) / 'icl.jsonl'), + ) + # empirical number from the small test dataset + assert dl.max_answer_length == 7 + + +def test_qa_get_answer_from_example_with_no_cot( + tmp_path: Path, tiny_gpt2_tokenizer: transformers.AutoTokenizer): + + local_data = os.path.join(os.path.dirname(__file__), 'local_data') + dataset_uri = f'{local_data}/triviaqa_small.jsonl' + + tmp_path_to_broadcast = str(os.path.abspath(tmp_path)) + gathered_paths = dist.all_gather_object(tmp_path_to_broadcast) + dl = InContextLearningGenerationTaskWithAnswersDataset( + dataset_uri=dataset_uri, + tokenizer=tiny_gpt2_tokenizer, + max_seq_len=1024, + pad_tok_id=tiny_gpt2_tokenizer.eos_token_id, + num_fewshot=0, + fewshot_random_seed=1234, + prompt_string='', + example_delimiter='\n', + continuation_delimiter=': ', + cot_delimiter=' ### ', + destination_path=str(Path(gathered_paths[0]) / 'icl.jsonl'), + ) + answer = dl.get_answer_from_example({ + 'context': 'empty', + 'answer': 'this is the correct answer', + 'chain_of_thought': "Let's think step by step. " + }) + assert answer == 'this is the correct answer' + + +def test_qa_get_answer_from_example_with_cot( + tmp_path: Path, tiny_gpt2_tokenizer: transformers.AutoTokenizer): + + local_data = os.path.join(os.path.dirname(__file__), 'local_data') + dataset_uri = f'{local_data}/triviaqa_small.jsonl' + + tmp_path_to_broadcast = str(os.path.abspath(tmp_path)) + gathered_paths = dist.all_gather_object(tmp_path_to_broadcast) + dl = InContextLearningGenerationTaskWithAnswersDataset( + dataset_uri=dataset_uri, + tokenizer=tiny_gpt2_tokenizer, + max_seq_len=1024, + pad_tok_id=tiny_gpt2_tokenizer.eos_token_id, + num_fewshot=0, + fewshot_random_seed=1234, + prompt_string='', + example_delimiter='\n', + continuation_delimiter=': ', + cot_delimiter=' ### ', + destination_path=str(Path(gathered_paths[0]) / 'icl.jsonl'), + ) + dl.has_cot = True + answer = dl.get_answer_from_example({ + 'context': 'empty', + 'answer': 'this is the correct answer', + 'chain_of_thought': "Let's think step by step. " + }) + assert answer == "Let's think step by step. ### this is the correct answer" + + +def test_qa_tokenize_example(tiny_gpt2_tokenizer: transformers.AutoTokenizer, + tmp_path: Path): + + local_data = os.path.join(os.path.dirname(__file__), 'local_data') + dataset_uri = f'{local_data}/triviaqa_small.jsonl' + + tmp_path_to_broadcast = str(os.path.abspath(tmp_path)) + gathered_paths = dist.all_gather_object(tmp_path_to_broadcast) + dl = InContextLearningGenerationTaskWithAnswersDataset( + dataset_uri=dataset_uri, + tokenizer=tiny_gpt2_tokenizer, + max_seq_len=1024, + pad_tok_id=tiny_gpt2_tokenizer.eos_token_id, + num_fewshot=0, + fewshot_random_seed=1234, + prompt_string='', + example_delimiter='\n', + continuation_delimiter=': ', + cot_delimiter=' ### ', + destination_path=str(Path(gathered_paths[0]) / 'icl.jsonl'), + ) + dl.has_cot = True + tokenized_example = dl.tokenize_example( + 'starting prompt', 'a context', { + 'context': 'empty', + 'answer': 'this is the correct answer', + 'aliases': ['this is the right answer', 'this is the best answer'], + 'chain_of_thought': "Let's think step by step. " + }) + assert 'aliases' in tokenized_example + assert tokenized_example['aliases'] == [ + 'this is the right answer', 'this is the best answer' + ] + + +def test_code_adjust_padding(tiny_gpt2_tokenizer: transformers.AutoTokenizer, + tmp_path: Path): + local_data = os.path.join(os.path.dirname(__file__), 'local_data') + dataset_uri = f'{local_data}/human_eval_small.jsonl' + tokenizer = tiny_gpt2_tokenizer + seqlen = 2048 + num_fewshot = 0 + prompt_string = '' + gen_kwargs = {'temperature': .9, 'top_p': .95, 'num_beams': 9000} + + dl = InContextLearningCodeEvalDataset( + dataset_uri=dataset_uri, + tokenizer=tokenizer, + max_seq_len=seqlen, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=num_fewshot, + fewshot_random_seed=1, + prompt_string=prompt_string, + example_delimiter='\n', + prelimiter='Code start:', + continuation_delimiter='\nPlease code:', + destination_path=str(tmp_path / 'test_human_eval_small.jsonl'), + generation_kwargs=gen_kwargs, + generations_per_sample=10, + ) + + assert all( + len(data['prompt']) == 148 + for data in dl.dataset) # pyright: ignore [reportGeneralTypeIssues] + + +def test_code_update_gen_kwargs(tiny_gpt2_tokenizer: transformers.AutoTokenizer, + tmp_path: Path): + local_data = os.path.join(os.path.dirname(__file__), 'local_data') + dataset_uri = f'{local_data}/human_eval_small.jsonl' + tokenizer = tiny_gpt2_tokenizer + seqlen = 2048 + num_fewshot = 0 + prompt_string = '' + gen_kwargs = {'temperature': .9, 'top_p': .95, 'num_beams': 9000} + + dl = InContextLearningCodeEvalDataset( + dataset_uri=dataset_uri, + tokenizer=tokenizer, + max_seq_len=seqlen, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=num_fewshot, + fewshot_random_seed=1, + prompt_string=prompt_string, + example_delimiter='\n', + prelimiter='Code start:', + continuation_delimiter='\nPlease code:', + destination_path=str(tmp_path / 'test_human_eval_small.jsonl'), + generation_kwargs=gen_kwargs, + generations_per_sample=10, + ) + assert dl.base_batch['generation_kwargs']['num_beams'] == 9000 + assert dl.base_batch['generation_kwargs']['top_p'] == .95 + assert dl.base_batch['generation_kwargs']['temperature'] == .9 + assert dl.base_batch['generation_kwargs']['do_sample'] == True + + +def test_mc_tokenize_example(tiny_gpt2_tokenizer: transformers.AutoTokenizer, + tmp_path: Path): + local_data = os.path.join(os.path.dirname(__file__), 'local_data') + dataset_uri = f'{local_data}/mmlu_small.jsonl' + tokenizer = tiny_gpt2_tokenizer + seqlen = 2048 + num_fewshot = 0 + prompt_string = '' + seqlen = 2048 + dl = InContextLearningMultipleChoiceTaskDataset( + dataset_uri=dataset_uri, + tokenizer=tokenizer, + max_seq_len=seqlen, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=num_fewshot, + fewshot_random_seed=1, + prompt_string=prompt_string, + example_delimiter='\n', + continuation_delimiter=' ### ', + destination_path=str(tmp_path / 'test_human_eval_small.jsonl'), + ) + example = { + 'context': + "Who's the best eval researcher?\n A. Jeremy\n B. Tessa\n C. Max\n D. Other\nAnswer: ", + 'choices': ['A', 'B', 'C', 'D'], + 'gold': + 2 + } + tokenized_example = dl.tokenize_example( + prompt_and_fewshot='Answer the following: ', + ctxt=example['context'], + example=example) + unpadded_queries = [ + context[context != tokenizer.eos_token_id] + for context in tokenized_example['query'] + ] + untokenized_inputs = [ + tokenizer.decode(unpadded_input) for unpadded_input in unpadded_queries + ] + correct_output = [ + "Answer the following: Who's the best eval researcher?\n A. Jeremy\n B. Tessa\n C. Max\n D. Other\nAnswer: A", + "Answer the following: Who's the best eval researcher?\n A. Jeremy\n B. Tessa\n C. Max\n D. Other\nAnswer: B", + "Answer the following: Who's the best eval researcher?\n A. Jeremy\n B. Tessa\n C. Max\n D. Other\nAnswer: C", + "Answer the following: Who's the best eval researcher?\n A. Jeremy\n B. Tessa\n C. Max\n D. Other\nAnswer: D" + ] + assert untokenized_inputs == correct_output + + +@pytest.mark.parametrize('prelimiter', ['', 'This is a question: ']) +def test_schema_construct_context( + prelimiter: str, tiny_gpt2_tokenizer: transformers.AutoTokenizer, + tmp_path: Path): + local_data = os.path.join(os.path.dirname(__file__), 'local_data') + dataset_uri = f'{local_data}/winograd_small.jsonl' + tokenizer = tiny_gpt2_tokenizer + seqlen = 2048 + num_fewshot = 0 + seqlen = 2048 + dl = InContextLearningSchemaTaskDataset( + dataset_uri=dataset_uri, + tokenizer=tokenizer, + max_seq_len=seqlen, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=num_fewshot, + fewshot_random_seed=1, + prompt_string='', + prelimiter=prelimiter, + example_delimiter='\n', + continuation_delimiter=' ### ', + destination_path=str(tmp_path / 'test_human_eval_small.jsonl'), + ) + example = { + 'context_options': ['cont one', 'cont two'], + 'gold': 0, + 'continuation': 'this is a continuation' + } + constructed_context = dl.construct_context(example) + assert constructed_context == f'{prelimiter}cont one ### this is a continuation' + constructed_context = dl.construct_context(example, preceding_text='text') + assert constructed_context == f'{prelimiter}\ncont one ### this is a continuation' + + +@pytest.mark.parametrize('prelimiter', ['', 'This is a question: ']) +def test_schema_construct_multiple_contexts( + prelimiter: str, + tiny_gpt2_tokenizer: transformers.AutoTokenizer, + tmp_path: Path, +): + local_data = os.path.join(os.path.dirname(__file__), 'local_data') + dataset_uri = f'{local_data}/winograd_small.jsonl' + tokenizer = tiny_gpt2_tokenizer + seqlen = 2048 + num_fewshot = 0 + prompt_string = '' + seqlen = 2048 + dl = InContextLearningSchemaTaskDataset( + dataset_uri=dataset_uri, + tokenizer=tokenizer, + max_seq_len=seqlen, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=num_fewshot, + fewshot_random_seed=1, + prelimiter=prelimiter, + prompt_string=prompt_string, + example_delimiter='\n', + continuation_delimiter=' ### ', + destination_path=str(tmp_path / 'test_human_eval_small.jsonl'), + ) + example = { + 'context_options': [f'cont one', 'cont two'], + 'gold': 0, + 'continuation': 'this is a continuation' + } + constructed_contexts = dl._construct_multiple_contexts(example) + assert constructed_contexts == [ + f'{prelimiter}cont one', f'{prelimiter}cont two' + ] + constructed_contexts = dl._construct_multiple_contexts( + example, preceding_text='some text') + assert constructed_contexts == [ + f'{prelimiter}\ncont one ###', f'{prelimiter}\ncont two ###' + ] + + +def test_schema_tokenize_example( + tiny_gpt2_tokenizer: transformers.AutoTokenizer, tmp_path: Path): + local_data = os.path.join(os.path.dirname(__file__), 'local_data') + dataset_uri = f'{local_data}/winograd_small.jsonl' + tokenizer = tiny_gpt2_tokenizer + seqlen = 2048 + num_fewshot = 0 + prompt_string = '' + seqlen = 2048 + dl = InContextLearningSchemaTaskDataset( + dataset_uri=dataset_uri, + tokenizer=tokenizer, + max_seq_len=seqlen, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=num_fewshot, + fewshot_random_seed=1, + prompt_string=prompt_string, # pyright: ignore + example_delimiter='\n', # pyright: ignore + continuation_delimiter=' ### ', + destination_path=str(tmp_path / + 'test_human_eval_small.jsonl'), # pyright: ignore + ) + example = { + 'context_options': ['context one', 'context two'], + 'gold': 0, + 'continuation': 'this is a continuation' + } + tokenized_example = dl.tokenize_example( + prompt_and_fewshot='prompt ', + context_options=example['context_options'], + example=example) + assert all( + tiny_gpt2_tokenizer.decode(cont) == ' this is a continuation' + for cont in tokenized_example['answer']) + unpadded_inputs = [ + context[context != tokenizer.eos_token_id] + for context in tokenized_example['context_options'] + ] + untokenized_inputs = [ + tokenizer.decode(unpadded_input) for unpadded_input in unpadded_inputs + ] + assert untokenized_inputs == [ + 'prompt context one this is a continuation', + 'prompt context two this is a continuation' + ] + + +@pytest.mark.parametrize('dataset_uri', ['mmlu_small.jsonl']) +def test_mc_task_dataloader_subcategories( + dataset_uri: str, tiny_gpt2_tokenizer: transformers.AutoTokenizer, + tmp_path: Path): + + local_data = os.path.join(os.path.dirname(__file__), 'local_data') + + tokenizer = tiny_gpt2_tokenizer + dataset_uri = f'{local_data}/{dataset_uri}' + batch_size = 8 + seqlen = 64 + dls = get_icl_task_dataloader( + 'multiple_choice', + dataset_uri=dataset_uri, + tokenizer=tokenizer, + batch_size=batch_size, + max_seq_len=seqlen, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=2, + prompt_string= + 'The following are multiple choice questions (with answers).\n', + example_delimiter='\n', + continuation_delimiter='Answer: ', + destination_path=str(tmp_path / 'icl.jsonl'), + has_categories=True) + assert isinstance(dls, dict) + + assert 'computer_security' in dls + dl = dls['computer_security'] + assert isinstance(dl.dataloader, DataLoader) # pyright + batch = next(dl.dataloader._get_iterator()) + assert dl.dataloader.__len__() == 2 + assert 'input_ids' in batch + assert tuple(batch['input_ids'].shape) == (batch_size, seqlen) + assert 'attention_mask' in batch + assert tuple(batch['attention_mask'].shape) == (batch_size, seqlen) + assert 'continuation_indices' in batch + assert isinstance(batch['continuation_indices'], list) and len( + batch['continuation_indices']) == batch_size + assert 'mode' in batch + assert batch['mode'] == 'icl_task' + min_idx = min(batch['continuation_indices'][0]).item() + max_idx = max(batch['continuation_indices'][0]).item() + assert tokenizer.decode(batch['input_ids'][0][min_idx:max_idx + 1]) == ' A' + + +@pytest.mark.parametrize('dataset_uri', [ + 'pubmed_sm.jsonl', +]) +def test_lm_task_dataloader_extra_space( + dataset_uri: str, tiny_gpt2_tokenizer: transformers.AutoTokenizer, + tmp_path: Path): + + local_data = os.path.join(os.path.dirname(__file__), 'local_data') + + tokenizer = tiny_gpt2_tokenizer + dataset_uri = f'{local_data}/{dataset_uri}' + batch_size = 2 + seqlen = 64 + dl = get_icl_task_dataloader('language_modeling', + dataset_uri=dataset_uri, + tokenizer=tokenizer, + batch_size=batch_size, + max_seq_len=seqlen, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=10, + prompt_string='', + example_delimiter='\n', + continuation_delimiter=' ', + destination_path=str(tmp_path / 'icl.jsonl')) + assert isinstance(dl, DataSpec) + assert isinstance(dl.dataloader, DataLoader) # pyright + batch = next(dl.dataloader._get_iterator()) + + assert 'input_ids' in batch + assert tuple(batch['input_ids'].shape) == (batch_size, seqlen) + assert 'attention_mask' in batch + assert tuple(batch['attention_mask'].shape) == (batch_size, seqlen) + assert 'continuation_indices' in batch + assert isinstance(batch['continuation_indices'], list) and len( + batch['continuation_indices']) == batch_size + assert 'mode' in batch + assert batch['mode'] == 'icl_task' + min_idx = min(batch['continuation_indices'][0]).item() + max_idx = max(batch['continuation_indices'][0]).item() + assert ' ' not in tokenizer.decode(batch['input_ids'][0][0:max_idx + 1]) + assert tokenizer.decode(batch['input_ids'][0][min_idx:max_idx + + 1]) == ' yes' + + +@pytest.mark.parametrize('dataset_uri', [ + 'lambada_small.jsonl', +]) +def test_lm_task_dataloader(dataset_uri: str, + tiny_gpt2_tokenizer: transformers.AutoTokenizer, + tmp_path: Path): + + local_data = os.path.join(os.path.dirname(__file__), 'local_data') + + tokenizer = tiny_gpt2_tokenizer + dataset_uri = f'{local_data}/{dataset_uri}' + batch_size = 2 + seqlen = 64 + dl = get_icl_task_dataloader('language_modeling', + dataset_uri=dataset_uri, + tokenizer=tokenizer, + batch_size=batch_size, + max_seq_len=seqlen, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=0, + prompt_string='', + example_delimiter='\n', + continuation_delimiter='', + destination_path=str(tmp_path / 'icl.jsonl')) + assert isinstance(dl, DataSpec) + assert isinstance(dl.dataloader, DataLoader) # pyright + batch = next(dl.dataloader._get_iterator()) + + assert 'input_ids' in batch + assert tuple(batch['input_ids'].shape) == (batch_size, seqlen) + assert 'attention_mask' in batch + assert tuple(batch['attention_mask'].shape) == (batch_size, seqlen) + assert 'continuation_indices' in batch + assert isinstance(batch['continuation_indices'], list) and len( + batch['continuation_indices']) == batch_size + assert 'mode' in batch + assert batch['mode'] == 'icl_task' + min_idx = min(batch['continuation_indices'][0]).item() + max_idx = max(batch['continuation_indices'][0]).item() + assert tokenizer.decode(batch['input_ids'][0][min_idx:max_idx + + 1]) == ' glen' + + +@pytest.mark.parametrize('dataset_uri', ['winograd_small.jsonl']) +@pytest.mark.parametrize('prelimiter', ['', 'This is a question: ']) +def test_schema_task_dataloader(dataset_uri: str, prelimiter: str, + tiny_gpt2_tokenizer: transformers.AutoTokenizer, + tmp_path: Path): + + local_data = os.path.join(os.path.dirname(__file__), 'local_data') + + tokenizer = tiny_gpt2_tokenizer + dataset_uri = f'{local_data}/{dataset_uri}' + batch_size = 2 + seqlen = 64 + dl = get_icl_task_dataloader('schema', + dataset_uri=dataset_uri, + tokenizer=tokenizer, + batch_size=batch_size, + max_seq_len=seqlen, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=1, + prompt_string='', + example_delimiter='\n', + question_prelimiter=prelimiter, + continuation_delimiter='', + destination_path=str(tmp_path / 'icl.jsonl')) + assert isinstance(dl, DataSpec) + assert isinstance(dl.dataloader, DataLoader) + batch = next(dl.dataloader._get_iterator()) + + choices_per_question = 2 + assert 'input_ids' in batch + assert tuple(batch['input_ids'].shape) == (batch_size, seqlen) + assert 'attention_mask' in batch + assert tuple(batch['attention_mask'].shape) == (batch_size, seqlen) + assert 'continuation_indices' in batch + assert isinstance(batch['continuation_indices'], list) and len( + batch['continuation_indices']) == batch_size + assert 'mode' in batch + assert batch['mode'] == 'icl_task' + assert 'gold_indices' in batch + assert isinstance(batch['gold_indices'], list) and len( + batch['gold_indices']) == batch_size // choices_per_question + assert 'choice_groupings' in batch + assert isinstance(batch['choice_groupings'], list) and len( + batch['choice_groupings']) == batch_size // choices_per_question + + min_idx = min(batch['continuation_indices'][0]).item() + max_idx = max(batch['continuation_indices'][0]).item() + assert tokenizer.decode(batch['input_ids'][0][min_idx:max_idx + + 1]) == ' feared violence.' + + +@pytest.mark.parametrize('dataset_uri', ['winograd_small.jsonl']) +def test_schema_task_dataloader_sentpiece_tokenizer(dataset_uri: str, + tmp_path: Path): + + local_data = os.path.join(os.path.dirname(__file__), 'local_data') + tokenizer = transformers.AutoTokenizer.from_pretrained( + 'huggyllama/llama-7b', # type: ignore reportUnboundVariable + use_fast=False) + dataset_uri = f'{local_data}/{dataset_uri}' + batch_size = 2 + seqlen = 64 + dl = get_icl_task_dataloader('schema', + dataset_uri=dataset_uri, + tokenizer=tokenizer, + batch_size=batch_size, + max_seq_len=seqlen, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=1, + prompt_string='', + example_delimiter='\n', + continuation_delimiter=' ', + destination_path=str(tmp_path / 'icl.jsonl')) + assert isinstance(dl, DataSpec) + assert isinstance(dl.dataloader, DataLoader) + batch = next(dl.dataloader._get_iterator()) + + choices_per_question = 2 + assert 'input_ids' in batch + assert tuple(batch['input_ids'].shape) == (batch_size, seqlen) + assert 'attention_mask' in batch + assert tuple(batch['attention_mask'].shape) == (batch_size, seqlen) + assert 'continuation_indices' in batch + assert isinstance(batch['continuation_indices'], list) and len( + batch['continuation_indices']) == batch_size + assert 'mode' in batch + assert batch['mode'] == 'icl_task' + assert 'gold_indices' in batch + assert isinstance(batch['gold_indices'], list) and len( + batch['gold_indices']) == batch_size // choices_per_question + assert 'choice_groupings' in batch + assert isinstance(batch['choice_groupings'], list) and len( + batch['choice_groupings']) == batch_size // choices_per_question + + max_idx = max(batch['continuation_indices'][0]).item() + assert tokenizer.decode( + batch['input_ids'][0][0:max_idx + 1] + ) == "The trophy doesn't fit into the brown suitcase because the suitcase is too small. \nThe city councilmen refused the demonstrators a permit because the city councilmen feared violence." + + +@pytest.mark.parametrize('dataset_uri', ['lambada_small.jsonl']) +@pytest.mark.parametrize('num_fewshot', [0, 1]) +def test_lm_task_dataloader_opt_tokenizer( + tiny_opt_tokenizer: transformers.AutoTokenizer, dataset_uri: str, + num_fewshot: int, tmp_path: Path): + + local_data = os.path.join(os.path.dirname(__file__), 'local_data') + + tokenizer = tiny_opt_tokenizer + dataset_uri = f'{local_data}/{dataset_uri}' + batch_size = 2 + seqlen = 512 + dl = get_icl_task_dataloader('language_modeling', + dataset_uri=dataset_uri, + tokenizer=tokenizer, + batch_size=batch_size, + max_seq_len=seqlen, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=num_fewshot, + prompt_string='', + example_delimiter='\n', + continuation_delimiter='', + destination_path=str(tmp_path / 'icl.jsonl')) + assert isinstance(dl, DataSpec) + assert isinstance(dl.dataloader, DataLoader) # pyright + batch = next(dl.dataloader._get_iterator()) + + assert 'input_ids' in batch + assert tuple(batch['input_ids'].shape) == (batch_size, seqlen) + assert 'attention_mask' in batch + assert tuple(batch['attention_mask'].shape) == (batch_size, seqlen) + assert 'continuation_indices' in batch + assert isinstance(batch['continuation_indices'], list) and len( + batch['continuation_indices']) == batch_size + assert 'mode' in batch + assert batch['mode'] == 'icl_task' + min_idx = min(batch['continuation_indices'][0]).item() + max_idx = max(batch['continuation_indices'][0]).item() + assert tokenizer.decode(batch['input_ids'][0][min_idx:max_idx + + 1]) == ' glen' + assert tokenizer.decode(batch['input_ids'][0][0:min_idx]).startswith('') + assert tokenizer.decode(batch['input_ids'][0][0:min_idx]).count('') == 1 + + +@pytest.mark.parametrize('dataset_uri', ['piqa_small.jsonl']) +@pytest.mark.parametrize('num_fewshot', [0, 1]) +def test_mc_task_dataloader_opt_tokenizer( + tiny_opt_tokenizer: transformers.AutoTokenizer, dataset_uri: str, + num_fewshot: int, tmp_path: Path): + + local_data = os.path.join(os.path.dirname(__file__), 'local_data') + + tokenizer = tiny_opt_tokenizer + + dataset_uri = f'{local_data}/{dataset_uri}' + batch_size = 4 + seqlen = 64 + dl = get_icl_task_dataloader('multiple_choice', + dataset_uri=dataset_uri, + tokenizer=tokenizer, + batch_size=batch_size, + max_seq_len=seqlen, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=num_fewshot, + prompt_string='', + example_delimiter='\n', + continuation_delimiter=': ', + destination_path=str(tmp_path / 'icl.jsonl')) + assert isinstance(dl, DataSpec) + assert isinstance(dl.dataloader, DataLoader) # pyright + batch = next(dl.dataloader._get_iterator()) + + choices_per_question = 2 + assert dl.get_num_samples_in_batch(batch) == 2 + assert 'input_ids' in batch + assert tuple(batch['input_ids'].shape) == (batch_size, seqlen) + assert 'attention_mask' in batch + assert tuple(batch['attention_mask'].shape) == (batch_size, seqlen) + assert 'continuation_indices' in batch + assert isinstance(batch['continuation_indices'], list) and len( + batch['continuation_indices']) == batch_size + assert 'mode' in batch + assert batch['mode'] == 'icl_task' + assert 'gold_indices' in batch + assert isinstance(batch['gold_indices'], list) and len( + batch['gold_indices']) == batch_size // choices_per_question + assert 'choice_groupings' in batch + assert isinstance(batch['choice_groupings'], list) and len( + batch['choice_groupings']) == batch_size // choices_per_question + + min_idx = min(batch['continuation_indices'][0]).item() + max_idx = max(batch['continuation_indices'][0]).item() + assert tokenizer.decode(batch['input_ids'][0][min_idx:max_idx + + 1]) == ' Pour it onto a plate' + assert tokenizer.decode(batch['input_ids'][0][0:min_idx]).startswith('') + assert tokenizer.decode(batch['input_ids'][0][0:min_idx]).count('') == 1 + + +@pytest.mark.parametrize('dataset_uri', ['piqa_small.jsonl']) +@pytest.mark.parametrize('num_fewshot', [0, 1]) +def test_mc_split_batch(tiny_opt_tokenizer: transformers.AutoTokenizer, + dataset_uri: str, num_fewshot: int, tmp_path: Path): + + local_data = os.path.join(os.path.dirname(__file__), 'local_data') + + tokenizer = tiny_opt_tokenizer + + dataset_uri = f'{local_data}/{dataset_uri}' + batch_size = 4 + seqlen = 512 + dl = get_icl_task_dataloader('multiple_choice', + dataset_uri=dataset_uri, + tokenizer=tokenizer, + batch_size=batch_size, + max_seq_len=seqlen, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=num_fewshot, + prompt_string='', + example_delimiter='\n', + continuation_delimiter=': ', + destination_path=str(tmp_path / 'icl.jsonl')) + assert isinstance(dl, DataSpec) + assert isinstance(dl.dataloader, DataLoader) # pyright + batch = next(dl.dataloader._get_iterator()) + choices_per_question = 2 + real_microbatch_size = batch_size // 2 + logical_microbatch_size = real_microbatch_size // choices_per_question + microbatches = dl.split_batch(batch, logical_microbatch_size) + assert len(microbatches) == 2 + for i, microbatch in enumerate(microbatches): + assert dl.get_num_samples_in_batch(microbatch) == 1 + assert 'input_ids' in microbatch + assert tuple(microbatch['input_ids'].shape) == (real_microbatch_size, + seqlen) + assert 'attention_mask' in microbatch + assert tuple( + microbatch['attention_mask'].shape) == (real_microbatch_size, + seqlen) + assert 'continuation_indices' in microbatch + assert isinstance(microbatch['continuation_indices'], list) and len( + microbatch['continuation_indices']) == real_microbatch_size + assert 'mode' in microbatch + assert microbatch['mode'] == 'icl_task' + assert 'gold_indices' in microbatch + assert isinstance(microbatch['gold_indices'], list) and len( + microbatch['gold_indices'] + ) == real_microbatch_size // choices_per_question + assert 'choice_groupings' in microbatch + assert isinstance(microbatch['choice_groupings'], list) and len( + microbatch['choice_groupings'] + ) == real_microbatch_size // choices_per_question + + min_idx = min(microbatch['continuation_indices'][0]).item() + max_idx = max(microbatch['continuation_indices'][0]).item() + if i == 0: + assert tokenizer.decode( + microbatch['input_ids'][0][min_idx:max_idx + + 1]) == ' Pour it onto a plate' + elif i == 1: + assert tokenizer.decode( + microbatch['input_ids'][0][min_idx:max_idx + 1] + ) == ' Weld the metal together to get it to stay firmly in place' + assert tokenizer.decode( + microbatch['input_ids'][0][0:min_idx]).startswith('') + assert tokenizer.decode( + microbatch['input_ids'][0][0:min_idx]).count('') == 1 + + +@pytest.mark.parametrize('dataset_uri', ['triviaqa_small.jsonl']) +def test_qa_split_batch(tiny_opt_tokenizer: transformers.AutoTokenizer, + dataset_uri: str, tmp_path: Path): + + local_data = os.path.join(os.path.dirname(__file__), 'local_data') + dataset_uri = f'{local_data}/{dataset_uri}' + tokenizer = tiny_opt_tokenizer + + tmp_path_to_broadcast = str(os.path.abspath(tmp_path)) + gathered_paths = dist.all_gather_object(tmp_path_to_broadcast) # for dist + dl = get_icl_task_dataloader( + icl_task_type='generation_task_with_answers', + dataset_uri=dataset_uri, + tokenizer=tokenizer, + batch_size=8, + max_seq_len=1024, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=0, + prompt_string='', + example_delimiter='\n', + continuation_delimiter=': ', + destination_path=str(Path(gathered_paths[0]) / 'icl.jsonl'), + ) + + assert isinstance(dl, DataSpec) # pyright + + batch = next(iter(dl.dataloader)) + split_batch = dl.split_batch(batch, 3) + + assert len(split_batch) == 2 + split1 = split_batch[0] + split2 = split_batch[1] + + assert split1['input_ids'].shape[0] == 3 + assert split2['input_ids'].shape[0] == 1 + + assert split1['attention_mask'].shape[0] == 3 + assert split2['attention_mask'].shape[0] == 1 + + assert isinstance(split1['mode'], str) + assert isinstance(split2['mode'], str) + + assert len(split1['labels']) == 3 + assert len(split2['labels']) == 1 + assert all(isinstance(v, list) for v in split1['labels'] + split2['labels']) + + assert isinstance(split1['generation_kwargs']['max_new_tokens'], int) + assert isinstance(split2['generation_kwargs']['max_new_tokens'], int) + + assert isinstance(split1['generation_kwargs'], dict) + assert isinstance(split2['generation_kwargs'], dict) + + +@pytest.mark.parametrize('dataset_uri', ['triviaqa_small.jsonl']) +@pytest.mark.parametrize('num_fewshot', [0]) +@pytest.mark.parametrize('prompt_string', ['I am a prompt', '']) +def test_qa_task_dataloader_w_null_eos( + dataset_uri: str, tiny_gpt2_tokenizer: transformers.AutoTokenizer, + tmp_path: Path, num_fewshot: int, prompt_string: str): + + local_data = os.path.join(os.path.dirname(__file__), 'local_data') + + tokenizer = tiny_gpt2_tokenizer + dataset_uri = f'{local_data}/{dataset_uri}' + batch_size = 4 + seqlen = 512 + tiny_gpt2_tokenizer.eos_token_id = None + with pytest.raises(ValueError): + _ = get_icl_task_dataloader('generation_task_with_answers', + dataset_uri, + tokenizer, + batch_size, + max_seq_len=seqlen, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=num_fewshot, + prompt_string=prompt_string, + example_delimiter='\n', + question_prelimiter='Q: ', + continuation_delimiter='\nA:', + destination_path=str( + tmp_path / f'icl_{num_fewshot}.jsonl')) + + +@pytest.mark.parametrize('dataset_uri', ['triviaqa_small.jsonl']) +@pytest.mark.parametrize('num_fewshot', [0, 2]) +@pytest.mark.parametrize('prompt_string', ['I am a prompt', '']) +def test_qa_task_dataloader(dataset_uri: str, + tiny_gpt2_tokenizer: transformers.AutoTokenizer, + tmp_path: Path, num_fewshot: int, + prompt_string: str): + + local_data = os.path.join(os.path.dirname(__file__), 'local_data') + + tokenizer = tiny_gpt2_tokenizer + dataset_uri = f'{local_data}/{dataset_uri}' + batch_size = 4 + seqlen = 512 + # empirical number from the small test dataset + maximum_answer_length = 7 + dl = get_icl_task_dataloader('generation_task_with_answers', + dataset_uri=dataset_uri, + tokenizer=tokenizer, + batch_size=batch_size, + max_seq_len=seqlen, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=num_fewshot, + prompt_string=prompt_string, + example_delimiter='\n', + question_prelimiter='Q: ', + continuation_delimiter='\nA:', + destination_path=str( + tmp_path / f'icl_{num_fewshot}.jsonl')) + assert isinstance(dl, DataSpec) + + assert isinstance(dl.dataloader, DataLoader) # pyright + batch = next(dl.dataloader._get_iterator()) + + assert tuple(batch['input_ids'].shape) == (batch_size, + seqlen - maximum_answer_length) + assert tuple(batch['attention_mask'].shape) == (batch_size, seqlen - + maximum_answer_length) + assert batch['mode'] == 'generate' + # the maximum generation length from the small test data + + assert batch['generation_kwargs']['max_new_tokens'] == maximum_answer_length + assert all(item[0] == tokenizer.eos_token_id for item in batch['input_ids']) + + decoded_batch = tokenizer.batch_decode(batch['input_ids']) + assert all(item.count('Q: ') == num_fewshot + 1 for item in decoded_batch) + assert all(item.count('\nA:') == num_fewshot + 1 for item in decoded_batch) + + if len(prompt_string) > 0: + assert all(item.count('I am a prompt') == 1 for item in decoded_batch) + assert all( + set(found) == set(expected) for found, expected in zip( + batch['labels'], [['David Seville'], ['Skorpio', 'Scorpio']])) + assert decoded_batch[0].endswith( + 'Q: Who was the man behind The Chipmunks?\nA:') + assert decoded_batch[1].endswith( + 'Q: What star sign is Jamie Lee Curtis?\nA:') + assert 'eos_token_id' in batch['generation_kwargs'] + + +@pytest.mark.parametrize('dataset_uri', ['gsm8k_small.jsonl']) +@pytest.mark.parametrize('num_fewshot', [0, 2]) +def test_qa_task_with_cot_dataloader( + dataset_uri: str, tiny_gpt2_tokenizer: transformers.AutoTokenizer, + tmp_path: Path, num_fewshot: int): + + local_data = os.path.join(os.path.dirname(__file__), 'local_data') + + tokenizer = tiny_gpt2_tokenizer + dataset_uri = f'{local_data}/{dataset_uri}' + batch_size = 2 + seqlen = 512 + # empirical number from the small test dataset + maximum_answer_length = 132 + dl = get_icl_task_dataloader( + 'generation_task_with_answers', + dataset_uri=dataset_uri, + tokenizer=tokenizer, + batch_size=batch_size, + max_seq_len=seqlen, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=num_fewshot, + prompt_string='', + example_delimiter='\n', + question_prelimiter='Q: ', + continuation_delimiter="\nA: Let's think step by step. ", + cot_delimiter=' #### ', + destination_path=str(tmp_path / f'icl_{num_fewshot}.jsonl')) + assert isinstance(dl, DataSpec) + assert isinstance(dl.dataloader, DataLoader) # pyright + batch = next(dl.dataloader._get_iterator()) + assert tuple(batch['input_ids'].shape) == (batch_size, + seqlen - maximum_answer_length) + assert tuple(batch['attention_mask'].shape) == (batch_size, seqlen - + maximum_answer_length) + assert batch['mode'] == 'generate' + # the maximum generation length from the small test data + assert batch['generation_kwargs']['max_new_tokens'] == maximum_answer_length + assert all(item[0] == tokenizer.eos_token_id for item in batch['input_ids']) + decoded_batch = tokenizer.batch_decode(batch['input_ids']) + assert all(item.count('Q: ') == num_fewshot + 1 for item in decoded_batch) + assert all(item.count('\nA:') == num_fewshot + 1 for item in decoded_batch) + + assert batch['labels'] == [['18'], ['3']] + if num_fewshot == 0: + assert decoded_batch[0].endswith( + "Q: Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?\nA: Let's think step by step." + ) + assert decoded_batch[1].endswith( + "Q: A robe takes 2 bolts of blue fiber and half that much white fiber. How many bolts in total does it take?\nA: Let's think step by step." + ) + elif num_fewshot == 2: + assert decoded_batch[0].endswith( + "Q: Josh decides to try flipping a house. He buys a house for $80,000 and then puts in $50,000 in repairs. This increased the value of the house by 150%. How much profit did he make?\nA: Let's think step by step. The cost of the house and repairs came out to 80,000+50,000=$<<80000+50000=130000>>130,000\nHe increased the value of the house by 80,000*1.5=<<80000*1.5=120000>>120,000\nSo the new value of the house is 120,000+80,000=$<<120000+80000=200000>>200,000\nSo he made a profit of 200,000-130,000=$<<200000-130000=70000>>70,000 #### 70000\nQ: James decides to run 3 sprints 3 times a week. He runs 60 meters each sprint. How many total meters does he run a week?\nA: Let's think step by step. He sprints 3*3=<<3*3=9>>9 times\nSo he runs 9*60=<<9*60=540>>540 meters #### 540\nQ: Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?\nA: Let's think step by step." + ) + assert decoded_batch[1].endswith( + "Q: Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?\nA: Let's think step by step. Janet sells 16 - 3 - 4 = <<16-3-4=9>>9 duck eggs a day.\nShe makes 9 * 2 = $<<9*2=18>>18 every day at the farmer’s market. #### 18\nQ: Josh decides to try flipping a house. He buys a house for $80,000 and then puts in $50,000 in repairs. This increased the value of the house by 150%. How much profit did he make?\nA: Let's think step by step. The cost of the house and repairs came out to 80,000+50,000=$<<80000+50000=130000>>130,000\nHe increased the value of the house by 80,000*1.5=<<80000*1.5=120000>>120,000\nSo the new value of the house is 120,000+80,000=$<<120000+80000=200000>>200,000\nSo he made a profit of 200,000-130,000=$<<200000-130000=70000>>70,000 #### 70000\nQ: A robe takes 2 bolts of blue fiber and half that much white fiber. How many bolts in total does it take?\nA: Let's think step by step." + ) + + +@pytest.mark.parametrize('dataset_uri', ['piqa_small.jsonl']) +@pytest.mark.parametrize('prelimiter', ['', 'This is a question: ']) +def test_mc_task_dataloader(dataset_uri: str, prelimiter: str, + tiny_gpt2_tokenizer: transformers.AutoTokenizer, + tmp_path: Path): + + local_data = os.path.join(os.path.dirname(__file__), 'local_data') + + tokenizer = tiny_gpt2_tokenizer + dataset_uri = f'{local_data}/{dataset_uri}' + batch_size = 2 + seqlen = 64 + example_delimiter = '\n' + dl = get_icl_task_dataloader('multiple_choice', + dataset_uri=dataset_uri, + tokenizer=tokenizer, + batch_size=batch_size, + max_seq_len=seqlen, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=1, + prompt_string='', + question_prelimiter=prelimiter, + example_delimiter=example_delimiter, + continuation_delimiter='\nA: ', + destination_path=str(tmp_path / 'icl.jsonl')) + assert isinstance(dl, DataSpec) + assert isinstance(dl.dataloader, DataLoader) # pyright + batch = next(dl.dataloader._get_iterator()) + + choices_per_question = 2 + assert 'input_ids' in batch + assert tuple(batch['input_ids'].shape) == (batch_size, seqlen) + assert 'attention_mask' in batch + assert tuple(batch['attention_mask'].shape) == (batch_size, seqlen) + assert 'continuation_indices' in batch + assert isinstance(batch['continuation_indices'], list) and len( + batch['continuation_indices']) == batch_size + assert 'mode' in batch + assert batch['mode'] == 'icl_task' + assert 'gold_indices' in batch + assert isinstance(batch['gold_indices'], list) and len( + batch['gold_indices']) == batch_size // choices_per_question + assert 'choice_groupings' in batch + assert isinstance(batch['choice_groupings'], list) and len( + batch['choice_groupings']) == batch_size // choices_per_question + + min_idx = min(batch['continuation_indices'][0]).item() + max_idx = max(batch['continuation_indices'][0]).item() + assert tokenizer.decode(batch['input_ids'][0][min_idx:max_idx + + 1]) == ' Pour it onto a plate' + q1 = 'how do you shake something?\nA: ' + a1 = 'move it up and down and side to side quickly.' + q2 = "When boiling butter, when it's ready, you can\nA:" + assert tokenizer.decode( + batch['input_ids'][0][:min_idx] + ) == f'{prelimiter}{q1}{a1}{example_delimiter}{prelimiter}{q2}' + assert tokenizer.decode(batch['input_ids'][0][min_idx:max_idx + + 1]) == ' Pour it onto a plate' + + +@pytest.mark.parametrize('dataset_uri', ['human_eval_small.jsonl']) +def test_code_eval_split_batch(dataset_uri: str, tmp_path: Path): + local_data = os.path.join(os.path.dirname(__file__), 'local_data') + dataset_uri = f'{local_data}/{dataset_uri}' + + tokenizer = transformers.AutoTokenizer.from_pretrained( + 'EleutherAI/gpt-neox-20b') # type: ignore reportUnboundVariable + + tmp_path_to_broadcast = str(os.path.abspath(tmp_path)) + gathered_paths = dist.all_gather_object(tmp_path_to_broadcast) + dl = get_icl_task_dataloader( + 'code_evaluation', + dataset_uri=dataset_uri, + tokenizer=tokenizer, + batch_size=5, + max_seq_len=1024, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=2, + prompt_string='', + example_delimiter='\n', + continuation_delimiter='', + destination_path=str(Path(gathered_paths[0]) / 'icl.jsonl'), + generations_per_sample=3, + ) + + assert isinstance(dl, DataSpec) # pyright + batches = list(dl.dataloader) + + for k in ('input_ids', 'attention_mask'): + assert [b[k].shape[0] for b in batches] == [5, 5, 2] + + list_keys = { + 'labels': str, + 'prompts': str, + 'tests': str, + 'entry_points': str, + 'test_inputs': list, + 'test_outputs': list, + 'languages': str, + } + + for batch, size in zip(batches, [5, 5, 2]): + for field, type_ in list_keys.items(): + assert len(batch[field]) == size + assert all(isinstance(val, type_) for val in batch[field]) + + static_keys = {'pass_at_k': (int, list), 'generation_kwargs': dict} + for batch in batches: + assert 'generation_kwargs' in batch + assert 'max_new_tokens' in batch['generation_kwargs'] + assert isinstance(batch['generation_kwargs']['max_new_tokens'], int) + for field, type_ in static_keys.items(): + assert isinstance(batch[field], type_) + + +@pytest.mark.parametrize('dataset_uri', ['human_eval_small.jsonl']) +@pytest.mark.parametrize('num_fewshot', [0, 2]) +@pytest.mark.parametrize('prompt_string', ['Please code:\n', '']) +@pytest.mark.parametrize('generations_per_sample', [1, 3]) +def test_code_eval_sentpiece_dataloader( + dataset_uri: str, tmp_path: Path, num_fewshot: int, prompt_string: str, + generations_per_sample: int, + tiny_llama_tokenizer: transformers.AutoTokenizer): + + local_data = os.path.join(os.path.dirname(__file__), 'local_data') + + tokenizer = tiny_llama_tokenizer + dataset_uri = f'{local_data}/{dataset_uri}' + batch_size = 5 + seqlen = 2048 + + dl = get_icl_task_dataloader('code_evaluation', + dataset_uri=dataset_uri, + tokenizer=tokenizer, + batch_size=batch_size, + max_seq_len=seqlen, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=num_fewshot, + prompt_string=prompt_string, + example_delimiter='\n', + continuation_delimiter='', + question_prelimiter='Code start: \n', + destination_path=str( + tmp_path / f'icl_{num_fewshot}.jsonl'), + generations_per_sample=generations_per_sample) + assert isinstance(dl, DataSpec) + + assert isinstance(dl.dataloader, DataLoader) # pyright + batches = list(dl.dataloader) + dataset_size = len(open(dataset_uri, 'r').read().strip().split('\n')) + dataset_size *= generations_per_sample + + max_prompt_length = 0 + + has_left_padding = [] + for i, batch in enumerate(batches): + if isinstance(dl.dataloader.dataset, InContextLearningCodeEvalDataset): + max_prompt_length = dl.dataloader.dataset.max_prompt_length + N = len(batches) + bs = batch_size if i < N - 1 else dataset_size - (N - 1) * batch_size + assert tuple(batch['input_ids'].shape) == (bs, max_prompt_length) + assert tuple(batch['attention_mask'].shape) == (bs, max_prompt_length) + assert batch['mode'] == 'generate' + # the maximum generation length from the small test data + assert batch['generation_kwargs']['max_new_tokens'] == 129 + has_left_padding.extend( + [item[0] == tokenizer.eos_token_id for item in batch['input_ids']]) + assert not all(has_left_padding) # longest should be pushed left + + decoded_batches = [ + tokenizer.batch_decode(batch['input_ids']) for batch in batches + ] + for decoded_batch in decoded_batches: + assert all( + item.count('Code start: \n') == num_fewshot + 1 + for item in decoded_batch) + + if len(prompt_string) > 0: + assert all( + item.count('Please code:\n') == 1 for item in decoded_batch) + + labels = [ + ' for idx, elem in enumerate(numbers):\n for idx2, elem2 in enumerate(numbers):\n if idx != idx2:\n distance = abs(elem - elem2)\n if distance < threshold:\n return True\n\n return False\n', + " result = []\n current_string = []\n current_depth = 0\n\n for c in paren_string:\n if c == '(':\n current_depth += 1\n current_string.append(c)\n elif c == ')':\n current_depth -= 1\n current_string.append(c)\n\n if current_depth == 0:\n result.append(''.join(current_string))\n current_string.clear()\n\n return result\n", + ' return number % 1.0\n', + ' balance = 0\n\n for op in operations:\n balance += op\n if balance < 0:\n return True\n\n return False\n', + ] + + # assert decoded_batch[0].endswith( + samples = [ + "Code start: \nfrom typing import List\n\n\ndef has_close_elements(numbers: List[float], threshold: float) -> bool:\n \"\"\" Check if in given list of numbers, are any two numbers closer to each other than\n given threshold.\n >>> has_close_elements([1.0, 2.0, 3.0], 0.5)\n False\n >>> has_close_elements([1.0, 2.8, 3.0, 4.0, 5.0, 2.0], 0.3)\n True\n \"\"\"\n", + "Code start: \nfrom typing import List\n\n\ndef separate_paren_groups(paren_string: str) -> List[str]:\n \"\"\" Input to this function is a string containing multiple groups of nested parentheses. Your goal is to\n separate those group into separate strings and return the list of those.\n Separate groups are balanced (each open brace is properly closed) and not nested within each other\n Ignore any spaces in the input string.\n >>> separate_paren_groups('( ) (( )) (( )( ))')\n ['()', '(())', '(()())']\n \"\"\"\n", + "Code start: \n\n\ndef truncate_number(number: float) -> float:\n \"\"\" Given a positive floating point number, it can be decomposed into\n and integer part (largest integer smaller than given number) and decimals\n (leftover part always smaller than 1).\n\n Return the decimal part of the number.\n >>> truncate_number(3.5)\n 0.5\n \"\"\"\n", + "Code start: \nfrom typing import List\n\n\ndef below_zero(operations: List[int]) -> bool:\n \"\"\" You're given a list of deposit and withdrawal operations on a bank account that starts with\n zero balance. Your task is to detect if at any point the balance of account fallls below zero, and\n at that point function should return True. Otherwise it should return False.\n >>> below_zero([1, 2, 3])\n False\n >>> below_zero([1, 2, -4, 5])\n True\n \"\"\"\n" + ] + for i in range(4): + for j in range(generations_per_sample): + k = i * generations_per_sample + j + b, n = divmod(k, batch_size) + assert batches[b]['labels'][n] == labels[i] + assert decoded_batches[b][n].endswith(samples[i]) + + +@pytest.mark.parametrize('dataset_uri', ['human_eval_small.jsonl']) +def test_code_eval_test_cases(dataset_uri: str, tmp_path: Path): + local_data = os.path.join(os.path.dirname(__file__), 'local_data') + + tokenizer = transformers.AutoTokenizer.from_pretrained( + 'huggyllama/llama-7b') # type: ignore reportUnboundVariable + dataset_uri = f'{local_data}/{dataset_uri}' + batch_size = 4 + seqlen = 512 + + dl = get_icl_task_dataloader('code_evaluation', + dataset_uri=dataset_uri, + tokenizer=tokenizer, + batch_size=batch_size, + max_seq_len=seqlen, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=0, + prompt_string='', + example_delimiter='\n', + continuation_delimiter='', + question_prelimiter='Code start: \n', + destination_path=str(tmp_path / f'icl_.jsonl'), + generations_per_sample=1) + assert isinstance(dl, DataSpec) + + assert isinstance(dl.dataloader, DataLoader) # pyright + batch = next(dl.dataloader._get_iterator()) + + max_prompt_length = 0 + if isinstance(dl.dataloader.dataset, InContextLearningCodeEvalDataset): + max_prompt_length = dl.dataloader.dataset.max_prompt_length + assert tuple(batch['input_ids'].shape) == (batch_size, max_prompt_length) + assert tuple(batch['attention_mask'].shape) == (batch_size, + max_prompt_length) + assert batch['mode'] == 'generate' + # the maximum generation length from the small test data + assert batch['generation_kwargs']['max_new_tokens'] == 129 + assert any(item[0] != tokenizer.eos_token_id + for item in batch['input_ids']) # longest should be pushed left + + mod = types.ModuleType('test_module') + for prompt, solution, inputs, outputs, entry_point in zip( + batch['prompts'], batch['labels'], batch['test_inputs'], + batch['test_outputs'], batch['entry_points']): + exec(prompt + solution, mod.__dict__) + for test_input, test_output in zip(inputs, outputs): + result = mod.__dict__[entry_point](*eval(test_input)) + assert result == eval(test_output) + + +@pytest.mark.parametrize('dataset_uri', ['human_eval_small.jsonl']) +def test_code_eval_pass_at_k_validity(dataset_uri: str, tmp_path: Path): + local_data = os.path.join(os.path.dirname(__file__), 'local_data') + + tokenizer = transformers.AutoTokenizer.from_pretrained( + 'huggyllama/llama-7b') # type: ignore reportUnboundVariable + dataset_uri = f'{local_data}/{dataset_uri}' + batch_size = 2 + seqlen = 64 + + with pytest.raises(ValueError, match=r'.* pass_at_k .*'): + get_icl_task_dataloader('code_evaluation', + dataset_uri=dataset_uri, + tokenizer=tokenizer, + batch_size=batch_size, + max_seq_len=seqlen, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=0, + prompt_string='', + example_delimiter='\n', + continuation_delimiter='', + question_prelimiter='Code start: \n', + destination_path=str(tmp_path / f'icl_.jsonl'), + pass_at_k=10, + generations_per_sample=1) + + +@pytest.mark.parametrize('dataset_uri', ['human_eval_small.jsonl']) +@pytest.mark.parametrize('num_fewshot', [0, 2]) +@pytest.mark.parametrize('prompt_string', ['Please code:\n', '']) +@pytest.mark.parametrize('generations_per_sample', [1, 3]) +def test_code_eval_task_dataloader(dataset_uri: str, tmp_path: Path, + num_fewshot: int, prompt_string: str, + generations_per_sample: int): + + local_data = os.path.join(os.path.dirname(__file__), 'local_data') + + tokenizer = transformers.AutoTokenizer.from_pretrained( + 'mosaicml/mpt-7b') # type: ignore reportUnboundVariable + dataset_uri = f'{local_data}/{dataset_uri}' + batch_size = 4 + seqlen = 2048 + + dl = get_icl_task_dataloader('code_evaluation', + dataset_uri=dataset_uri, + tokenizer=tokenizer, + batch_size=batch_size, + max_seq_len=seqlen, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=num_fewshot, + prompt_string=prompt_string, + example_delimiter='\n', + continuation_delimiter='', + question_prelimiter='Code start: \n', + destination_path=str( + tmp_path / f'icl_{num_fewshot}.jsonl'), + generations_per_sample=generations_per_sample, + generation_kwargs={ + 'temperature': .9, + 'top_k': 40 + }) + assert isinstance(dl, DataSpec) + + assert isinstance(dl.dataloader, DataLoader) # pyright + batches = list(dl.dataloader) + dataset_size = len(open(dataset_uri, 'r').read().strip().split('\n')) + dataset_size *= generations_per_sample + + has_left_padding = [] + for i, batch in enumerate(batches): + max_prompt_length = 0 + if isinstance(dl.dataloader.dataset, InContextLearningCodeEvalDataset): + max_prompt_length = dl.dataloader.dataset.max_prompt_length + N = len(batches) + bs = batch_size if i < N - 1 else dataset_size - (N - 1) * batch_size + assert tuple(batch['input_ids'].shape) == (bs, max_prompt_length) + assert tuple(batch['attention_mask'].shape) == (bs, max_prompt_length) + assert batch['mode'] == 'generate' + # the maximum generation length from the small test data + assert batch['generation_kwargs']['max_new_tokens'] == 122 + has_left_padding.extend( + [item[0] == tokenizer.eos_token_id for item in batch['input_ids']]) + assert not all(has_left_padding) # longest should be pushed left + + decoded_batches = [ + tokenizer.batch_decode(batch['input_ids']) for batch in batches + ] + for decoded_batch in decoded_batches: + assert all( + item.count('Code start: \n') == num_fewshot + 1 + for item in decoded_batch) + + if len(prompt_string) > 0: + assert all( + item.count('Please code:\n') == 1 for item in decoded_batch) + + labels = [ + ' for idx, elem in enumerate(numbers):\n for idx2, elem2 in enumerate(numbers):\n if idx != idx2:\n distance = abs(elem - elem2)\n if distance < threshold:\n return True\n\n return False\n', + " result = []\n current_string = []\n current_depth = 0\n\n for c in paren_string:\n if c == '(':\n current_depth += 1\n current_string.append(c)\n elif c == ')':\n current_depth -= 1\n current_string.append(c)\n\n if current_depth == 0:\n result.append(''.join(current_string))\n current_string.clear()\n\n return result\n", + ' return number % 1.0\n', + ' balance = 0\n\n for op in operations:\n balance += op\n if balance < 0:\n return True\n\n return False\n', + ] + + # assert decoded_batch[0].endswith( + samples = [ + "Code start: \nfrom typing import List\n\n\ndef has_close_elements(numbers: List[float], threshold: float) -> bool:\n \"\"\" Check if in given list of numbers, are any two numbers closer to each other than\n given threshold.\n >>> has_close_elements([1.0, 2.0, 3.0], 0.5)\n False\n >>> has_close_elements([1.0, 2.8, 3.0, 4.0, 5.0, 2.0], 0.3)\n True\n \"\"\"\n", + "Code start: \nfrom typing import List\n\n\ndef separate_paren_groups(paren_string: str) -> List[str]:\n \"\"\" Input to this function is a string containing multiple groups of nested parentheses. Your goal is to\n separate those group into separate strings and return the list of those.\n Separate groups are balanced (each open brace is properly closed) and not nested within each other\n Ignore any spaces in the input string.\n >>> separate_paren_groups('( ) (( )) (( )( ))')\n ['()', '(())', '(()())']\n \"\"\"\n", + "Code start: \n\n\ndef truncate_number(number: float) -> float:\n \"\"\" Given a positive floating point number, it can be decomposed into\n and integer part (largest integer smaller than given number) and decimals\n (leftover part always smaller than 1).\n\n Return the decimal part of the number.\n >>> truncate_number(3.5)\n 0.5\n \"\"\"\n", + "Code start: \nfrom typing import List\n\n\ndef below_zero(operations: List[int]) -> bool:\n \"\"\" You're given a list of deposit and withdrawal operations on a bank account that starts with\n zero balance. Your task is to detect if at any point the balance of account fallls below zero, and\n at that point function should return True. Otherwise it should return False.\n >>> below_zero([1, 2, 3])\n False\n >>> below_zero([1, 2, -4, 5])\n True\n \"\"\"\n" + ] + for i in range(4): + for j in range(generations_per_sample): + k = i * generations_per_sample + j + b, n = divmod(k, batch_size) + assert batches[b]['labels'][n] == labels[i] + assert decoded_batches[b][n].endswith(samples[i]) + + +@pytest.mark.parametrize('dataset_uri', ['human_eval_small.jsonl']) +@pytest.mark.parametrize('num_fewshot', [0, 1]) +def test_eval_split_batch(mpt_tokenizer: transformers.AutoTokenizer, + dataset_uri: str, num_fewshot: int, tmp_path: Path): + + local_data = os.path.join(os.path.dirname(__file__), 'local_data') + tokenizer = mpt_tokenizer # type: ignore reportUnboundVariable + dataset_uri = f'{local_data}/{dataset_uri}' + batch_size = 4 + seqlen = 512 + + dl = get_icl_task_dataloader('code_evaluation', + dataset_uri=dataset_uri, + tokenizer=tokenizer, + batch_size=batch_size, + max_seq_len=seqlen, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=num_fewshot, + prompt_string='', + example_delimiter='\n', + continuation_delimiter='', + question_prelimiter='Code start: \n', + destination_path=str( + tmp_path / f'icl_{num_fewshot}.jsonl'), + generations_per_sample=1, + generation_kwargs={ + 'temperature': .9, + 'top_k': 40 + }) + assert isinstance(dl, DataSpec) + assert isinstance(dl.dataloader, DataLoader) # pyright + batch = next(dl.dataloader._get_iterator()) + microbatch_size = 1 + microbatches = dl.split_batch(batch, microbatch_size) + assert len(microbatches) == 4 + for microbatch in microbatches: + assert dl.get_num_samples_in_batch(microbatch) == 1 + assert 'input_ids' in microbatch + # TODO: what should this be? + # assert tuple(microbatch['input_ids'].shape) == (microbatch_size, seqlen) + assert 'attention_mask' in microbatch + # assert tuple(microbatch['attention_mask'].shape) == (microbatch_size, seqlen) + assert isinstance(microbatch['generation_kwargs'], dict) + assert microbatch['generation_kwargs']['temperature'] == .9 + assert microbatch['generation_kwargs']['top_k'] == 40 + assert microbatch['generation_kwargs']['pad_token_id'] == 0 + assert microbatch['generation_kwargs']['num_beams'] == 1 + assert microbatch['generation_kwargs']['do_sample'] == True + assert microbatch['generation_kwargs']['use_cache'] == True + assert microbatch['generation_kwargs']['eos_token_id'] == 0 + + +@pytest.mark.parametrize('num_fewshot', [0, 5]) +@pytest.mark.parametrize('dataset_uri', ['lambada_small.jsonl']) +# @pytest.mark.gpu +# @pytest.mark.world_size(2) +def test_lm_task_evaluation(num_fewshot: int, dataset_uri: str, + tiny_gpt2_tokenizer: transformers.AutoTokenizer, + tmp_path: Path, + tiny_gpt2_model: transformers.AutoModelForCausalLM): + + in_memory_logger = InMemoryLogger( + ) # track the logged metrics in the in_memory_logger + local_data = os.path.join(os.path.dirname(__file__), 'local_data') + dataset_uri = f'{local_data}/{dataset_uri}' + tokenizer = tiny_gpt2_tokenizer + batch_size = 2 + dl = get_icl_task_dataloader( + 'language_modeling', + dataset_uri=dataset_uri, + tokenizer=tokenizer, + batch_size=batch_size, + max_seq_len=1024, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=num_fewshot, + prompt_string='', + example_delimiter='\n', + continuation_delimiter='', + destination_path=str(tmp_path / 'icl.jsonl'), + ) + + evaluator = Evaluator(label='lambada', + dataloader=dl, + metric_names=['InContextLearningLMAccuracy']) + + model = HuggingFaceModel( + model=tiny_gpt2_model, + tokenizer=tokenizer, + eval_metrics=[InContextLearningLMAccuracy()], + use_logits=True, + ) + + trainer = Trainer(model=model, max_duration='1ep', loggers=in_memory_logger) + trainer.eval(eval_dataloader=evaluator, subset_num_batches=2) + assert 'metrics/lambada/InContextLearningLMAccuracy' in in_memory_logger.data.keys( + ) + assert in_memory_logger.data['metrics/lambada/InContextLearningLMAccuracy'][ + 0][1].item() == 0 + + +@pytest.mark.parametrize('num_fewshot', [0, 5]) +@pytest.mark.parametrize('dataset_uri', ['winograd_small.jsonl']) +@pytest.mark.filterwarnings(r'ignore:Cannot split .* of length.*:UserWarning') +def test_schema_task_evaluation( + num_fewshot: int, dataset_uri: str, + tiny_gpt2_tokenizer: transformers.AutoTokenizer, tmp_path: Path, + tiny_gpt2_model: transformers.AutoModelForCausalLM): + + in_memory_logger = InMemoryLogger( + ) # track the logged metrics in the in_memory_logger + local_data = os.path.join(os.path.dirname(__file__), 'local_data') + dataset_uri = f'{local_data}/{dataset_uri}' + tokenizer = tiny_gpt2_tokenizer + batch_size = 8 + dl = get_icl_task_dataloader( + 'schema', + dataset_uri=dataset_uri, + tokenizer=tokenizer, + batch_size=batch_size, + max_seq_len=1024, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=num_fewshot, + prompt_string='', + example_delimiter='\n', + continuation_delimiter=': ', + destination_path=str(tmp_path / 'icl.jsonl'), + ) + + evaluator = Evaluator( + label='winograd', + dataloader=dl, + metric_names=['InContextLearningMultipleChoiceAccuracy']) + + model = HuggingFaceModel( + model=tiny_gpt2_model, + tokenizer=tokenizer, + eval_metrics=[InContextLearningMultipleChoiceAccuracy()], + use_logits=True, + ) + + trainer = Trainer(model=model, max_duration='1ba', loggers=in_memory_logger) + trainer.eval(eval_dataloader=evaluator) + assert 'metrics/winograd/InContextLearningMultipleChoiceAccuracy' in in_memory_logger.data.keys( + ) + assert in_memory_logger.data[ + 'metrics/winograd/InContextLearningMultipleChoiceAccuracy'][0][1].item( + ) > 0 + num_samples = 0 + with open(dataset_uri) as f: + for _ in f: + num_samples += 1 + assert trainer.state.eval_metrics['winograd'][ + 'InContextLearningMultipleChoiceAccuracy'].total == num_samples + + +@pytest.mark.parametrize('dataset_uri', ['mmlu_small.jsonl']) +@pytest.mark.parametrize('num_fewshot', [0, 5]) +@pytest.mark.gpu +@pytest.mark.world_size(2) +@pytest.mark.filterwarnings(r'ignore:Cannot split .* of length.*:UserWarning') +def test_mc_task_evaluation_subcategories( + dataset_uri: str, num_fewshot: int, + tiny_gpt2_model: transformers.AutoModelForCausalLM, + tiny_gpt2_tokenizer: transformers.AutoTokenizer, tmp_path: Path): + + in_memory_logger = InMemoryLogger( + ) # track the logged metrics in the in_memory_logger + local_data = os.path.join(os.path.dirname(__file__), 'local_data') + dataset_uri = f'{local_data}/{dataset_uri}' + tokenizer = tiny_gpt2_tokenizer + batch_size = 16 + max_seq_len = 64 + tmp_path_to_broadcast = str(os.path.abspath(tmp_path)) + gathered_paths = dist.all_gather_object(tmp_path_to_broadcast) + reproducibility.seed_all(1234) + dls = get_icl_task_dataloader('multiple_choice', + dataset_uri=dataset_uri, + tokenizer=tokenizer, + batch_size=batch_size, + max_seq_len=max_seq_len, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=num_fewshot, + prompt_string='', + example_delimiter='\n', + continuation_delimiter=': ', + destination_path=str( + Path(gathered_paths[0]) / 'icl.jsonl'), + has_categories=True) + + assert isinstance(dls, dict) + evaluators = [ + Evaluator(label='mmlu/' + k, + dataloader=dl, + metric_names=['InContextLearningMultipleChoiceAccuracy']) + for k, dl in dls.items() + ] + + model = HuggingFaceModel( + model=tiny_gpt2_model, + tokenizer=tiny_gpt2_tokenizer, + eval_metrics=[InContextLearningMultipleChoiceAccuracy()], + use_logits=True, + ) + + trainer = Trainer(model=model, loggers=in_memory_logger) + trainer.eval(eval_dataloader=evaluators) + assert 'metrics/mmlu/computer_security/InContextLearningMultipleChoiceAccuracy' in in_memory_logger.data.keys( + ) + assert in_memory_logger.data[ + 'metrics/mmlu/computer_security/InContextLearningMultipleChoiceAccuracy'][ + 0][1].item() >= 0 + total = trainer.state.eval_metrics['mmlu/computer_security'][ + 'InContextLearningMultipleChoiceAccuracy'].total + dist.all_reduce(total) # type: ignore + assert total.item() == 4 # type: ignore + + +@pytest.mark.parametrize('dataset_uri', + ['piqa_small.jsonl', 'hellaswag_small.jsonl']) +@pytest.mark.parametrize('num_fewshot', [0, 5]) +@pytest.mark.filterwarnings(r'ignore:Cannot split .* of length.*:UserWarning') +@pytest.mark.gpu +@pytest.mark.world_size(2) +def test_mc_task_evaluation(num_fewshot: int, dataset_uri: str, + tiny_gpt2_tokenizer: transformers.AutoTokenizer, + tmp_path: Path, + tiny_gpt2_model: transformers.AutoModelForCausalLM): + + in_memory_logger = InMemoryLogger( + ) # track the logged metrics in the in_memory_logger + local_data = os.path.join(os.path.dirname(__file__), 'local_data') + dataset_uri = f'{local_data}/{dataset_uri}' + tokenizer = tiny_gpt2_tokenizer + batch_size = 8 + tmp_path_to_broadcast = str(os.path.abspath(tmp_path)) + gathered_paths = dist.all_gather_object(tmp_path_to_broadcast) + + # seed because the fewshot selection is currently unseeded + reproducibility.seed_all(1234) + dl = get_icl_task_dataloader( + 'multiple_choice', + dataset_uri=dataset_uri, + tokenizer=tokenizer, + batch_size=batch_size, + max_seq_len=64, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=num_fewshot, + prompt_string='', + example_delimiter='\n', + continuation_delimiter=': ', + destination_path=str(Path(gathered_paths[0]) / 'icl.jsonl'), + ) + + evaluator = Evaluator( + label='mc', + dataloader=dl, + metric_names=['InContextLearningMultipleChoiceAccuracy']) + + model = HuggingFaceModel( + model=tiny_gpt2_model, + tokenizer=tiny_gpt2_tokenizer, + eval_metrics=[InContextLearningMultipleChoiceAccuracy()], + use_logits=True, + ) + + trainer = Trainer(model=model, max_duration='1ba', loggers=in_memory_logger) + trainer.eval(eval_dataloader=evaluator) + assert 'metrics/mc/InContextLearningMultipleChoiceAccuracy' in in_memory_logger.data.keys( + ) + assert in_memory_logger.data[ + 'metrics/mc/InContextLearningMultipleChoiceAccuracy'][0][1].item() >= 0 + num_samples = 0 + with open(dataset_uri) as f: + for _ in f: + num_samples += 1 + total = trainer.state.eval_metrics['mc'][ + 'InContextLearningMultipleChoiceAccuracy'].total + dist.all_reduce(total) # type: ignore + assert total.item() == num_samples # type: ignore + + +@pytest.mark.parametrize('num_fewshot', [0, 5]) +@pytest.mark.parametrize('dataset_uri', ['triviaqa_small.jsonl']) +@pytest.mark.filterwarnings( + r'ignore:.*The dataloader_len \(2\) is greater than the length.*:UserWarning' +) +@pytest.mark.filterwarnings(r'ignore:Cannot split .* of length.*:UserWarning') +@pytest.mark.gpu +@pytest.mark.world_size(2) +def test_qa_task_evaluation_opt_tokenizer( + tiny_opt_tokenizer: transformers.AutoTokenizer, + tiny_opt_model: transformers.AutoModelForCausalLM, num_fewshot: int, + dataset_uri: str, tmp_path: Path): + + in_memory_logger = InMemoryLogger( + ) # track the logged metrics in the in_memory_logger + local_data = os.path.join(os.path.dirname(__file__), 'local_data') + dataset_uri = f'{local_data}/{dataset_uri}' + tokenizer = tiny_opt_tokenizer + + batch_size = 4 + tmp_path_to_broadcast = str(os.path.abspath(tmp_path)) + gathered_paths = dist.all_gather_object(tmp_path_to_broadcast) + dl = get_icl_task_dataloader( + 'generation_task_with_answers', + dataset_uri=dataset_uri, + tokenizer=tokenizer, + batch_size=batch_size, + max_seq_len=1024, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=num_fewshot, + prompt_string='', + example_delimiter='\n', + continuation_delimiter=': ', + destination_path=str(Path(gathered_paths[0]) / 'icl.jsonl'), + ) + + evaluator = Evaluator( + label='triviaqa', + dataloader=dl, + metric_names=['InContextLearningGenerationExactMatchAccuracy']) + model = HuggingFaceModel( + model=tiny_opt_model, + tokenizer=tokenizer, + eval_metrics=[InContextLearningGenerationExactMatchAccuracy()], + use_logits=True, + ) + + trainer = Trainer(model=model, max_duration='1ba', loggers=in_memory_logger) + + trainer.eval(eval_dataloader=evaluator, subset_num_batches=2) + assert 'metrics/triviaqa/InContextLearningGenerationExactMatchAccuracy' in in_memory_logger.data.keys( + ) + assert in_memory_logger.data[ + 'metrics/triviaqa/InContextLearningGenerationExactMatchAccuracy'][0][ + 1].item() == 0 + + +@pytest.mark.parametrize('num_fewshot', [5]) +@pytest.mark.parametrize('dataset_uri', ['gsm8k_small.jsonl']) +@pytest.mark.gpu +@pytest.mark.world_size(2) +@pytest.mark.filterwarnings( + r'ignore:.*The dataloader_len \(2\) is greater than the length.*:UserWarning' +) +@pytest.mark.filterwarnings(r'ignore:Cannot split .* of length.*:UserWarning') +def test_qa_task_evaluation_with_cot_opt_tokenizer( + tiny_opt_tokenizer: transformers.AutoTokenizer, + tiny_opt_model: transformers.AutoModelForCausalLM, num_fewshot: int, + dataset_uri: str, tmp_path: Path): + + in_memory_logger = InMemoryLogger( + ) # track the logged metrics in the in_memory_logger + local_data = os.path.join(os.path.dirname(__file__), 'local_data') + dataset_uri = f'{local_data}/{dataset_uri}' + tokenizer = tiny_opt_tokenizer + + batch_size = 4 + tmp_path_to_broadcast = str(os.path.abspath(tmp_path)) + gathered_paths = dist.all_gather_object(tmp_path_to_broadcast) + dl = get_icl_task_dataloader( + 'generation_task_with_answers', + dataset_uri=dataset_uri, + tokenizer=tokenizer, + batch_size=batch_size, + max_seq_len=1024, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=num_fewshot, + prompt_string='', + example_delimiter='\n', + continuation_delimiter="A: Let's think step by step. ", + cot_delimiter=' #### ', + destination_path=str(Path(gathered_paths[0]) / 'icl.jsonl'), + ) + + evaluator = Evaluator( + label='gsm8k', + dataloader=dl, + metric_names=['InContextLearningGenerationExactMatchAccuracy']) + model = HuggingFaceModel( + model=tiny_opt_model, + tokenizer=tokenizer, + eval_metrics=[InContextLearningGenerationExactMatchAccuracy()], + use_logits=True, + ) + + trainer = Trainer(model=model, max_duration='1ba', loggers=in_memory_logger) + + trainer.eval(eval_dataloader=evaluator, subset_num_batches=2) + assert 'metrics/gsm8k/InContextLearningGenerationExactMatchAccuracy' in in_memory_logger.data.keys( + ) + assert in_memory_logger.data[ + 'metrics/gsm8k/InContextLearningGenerationExactMatchAccuracy'][0][ + 1].item() == 0 + + +@pytest.mark.parametrize('dataset_uri', ['triviaqa_small.jsonl']) +@pytest.mark.parametrize('num_fewshot', [0, 5]) +@pytest.mark.gpu +@pytest.mark.world_size(2) +@pytest.mark.filterwarnings( + r'ignore:.*The dataloader_len \(2\) is greater than the length.*:UserWarning' +) +def test_qa_task_evaluation(num_fewshot: int, dataset_uri: str, + tiny_gpt2_tokenizer: transformers.AutoTokenizer, + tiny_gpt2_model: transformers.AutoModelForCausalLM, + tmp_path: Path): + + in_memory_logger = InMemoryLogger( + ) # track the logged metrics in the in_memory_logger + local_data = os.path.join(os.path.dirname(__file__), 'local_data') + dataset_uri = f'{local_data}/{dataset_uri}' + tokenizer = tiny_gpt2_tokenizer + batch_size = 2 + tmp_path_to_broadcast = str(os.path.abspath(tmp_path)) + gathered_paths = dist.all_gather_object(tmp_path_to_broadcast) + dl = get_icl_task_dataloader( + 'generation_task_with_answers', + dataset_uri=dataset_uri, + tokenizer=tokenizer, + batch_size=batch_size, + max_seq_len=1024, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=num_fewshot, + prompt_string='', + example_delimiter='\n', + continuation_delimiter=': ', + destination_path=str(Path(gathered_paths[0]) / 'icl.jsonl'), + ) + + evaluator = Evaluator( + label='triviaqa', + dataloader=dl, + metric_names=['InContextLearningGenerationExactMatchAccuracy']) + + model = HuggingFaceModel( + model=tiny_gpt2_model, + tokenizer=tiny_gpt2_tokenizer, + eval_metrics=[InContextLearningGenerationExactMatchAccuracy()], + use_logits=True, + ) + + trainer = Trainer(model=model, max_duration='1ba', loggers=in_memory_logger) + + trainer.eval(eval_dataloader=evaluator, subset_num_batches=2) + assert 'metrics/triviaqa/InContextLearningGenerationExactMatchAccuracy' in in_memory_logger.data.keys( + ) + assert in_memory_logger.data[ + 'metrics/triviaqa/InContextLearningGenerationExactMatchAccuracy'][0][ + 1].item() == 0 + + +@pytest.mark.parametrize('dataset_uri', ['gsm8k_small.jsonl']) +@pytest.mark.parametrize('num_fewshot', [5]) +@pytest.mark.filterwarnings( + r'ignore:.*The dataloader_len \(2\) is greater than the length.*:UserWarning' +) +@pytest.mark.gpu +@pytest.mark.world_size(2) +def test_qa_task_with_cot_evaluation( + num_fewshot: int, dataset_uri: str, + tiny_gpt2_tokenizer: transformers.AutoTokenizer, + tiny_gpt2_model: transformers.AutoModelForCausalLM, tmp_path: Path): + + in_memory_logger = InMemoryLogger( + ) # track the logged metrics in the in_memory_logger + local_data = os.path.join(os.path.dirname(__file__), 'local_data') + dataset_uri = f'{local_data}/{dataset_uri}' + tokenizer = tiny_gpt2_tokenizer + batch_size = 2 + tmp_path_to_broadcast = str(os.path.abspath(tmp_path)) + gathered_paths = dist.all_gather_object(tmp_path_to_broadcast) + dl = get_icl_task_dataloader( + 'generation_task_with_answers', + dataset_uri=dataset_uri, + tokenizer=tokenizer, + batch_size=batch_size, + max_seq_len=1024, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=num_fewshot, + prompt_string='', + example_delimiter='\n', + continuation_delimiter="A: Let's think step by step", + cot_delimiter=' #### ', + destination_path=str(Path(gathered_paths[0]) / 'icl.jsonl'), + ) + + evaluator = Evaluator( + label='gsm8k', + dataloader=dl, + metric_names=['InContextLearningGenerationExactMatchAccuracy']) + + model = HuggingFaceModel( + model=tiny_gpt2_model, + tokenizer=tiny_gpt2_tokenizer, + eval_metrics=[InContextLearningGenerationExactMatchAccuracy()], + use_logits=True, + ) + + trainer = Trainer(model=model, max_duration='1ba', loggers=in_memory_logger) + + trainer.eval(eval_dataloader=evaluator, subset_num_batches=2) + assert 'metrics/gsm8k/InContextLearningGenerationExactMatchAccuracy' in in_memory_logger.data.keys( + ) + assert in_memory_logger.data[ + 'metrics/gsm8k/InContextLearningGenerationExactMatchAccuracy'][0][ + 1].item() == 0 + + +def test_code_eval_requires_envvar(monkeypatch: pytest.MonkeyPatch): + monkeypatch.delenv('CODE_EVAL_DEVICE', raising=False) + with pytest.raises( + ValueError, + match='Attempting to use InContextLearningCodeEvalAccuracy but.*'): + InContextLearningCodeEvalAccuracy().get_client() + + +def test_code_eval_requires_valid_envvar(monkeypatch: pytest.MonkeyPatch): + monkeypatch.setenv('CODE_EVAL_DEVICE', 'bigchungus') + with pytest.raises( + ValueError, + match='Environment variable `CODE_EVAL_DEVICE` must be on.*'): + InContextLearningCodeEvalAccuracy().get_client() + + +@pytest.mark.parametrize('dataset_uri', ['human_eval_small.jsonl']) +@pytest.mark.parametrize('num_fewshot', [0]) +@pytest.mark.parametrize('generations_per_sample', range(1, 3)) +@pytest.mark.gpu +@pytest.mark.world_size(2) +@pytest.mark.filterwarnings( + r'ignore:.*The dataloader_len \(2\) is greater than the length.*:UserWarning' +) +def test_code_eval_microbatching( + monkeypatch: pytest.MonkeyPatch, + tiny_opt_tokenizer: transformers.AutoTokenizer, + tiny_opt_model: transformers.AutoModelForCausalLM, num_fewshot: int, + dataset_uri: str, tmp_path: Path, generations_per_sample: int): + + monkeypatch.setenv('CODE_EVAL_DEVICE', 'LOCAL') + in_memory_logger = InMemoryLogger( + ) # track the logged metrics in the in_memory_logger + local_data = os.path.join(os.path.dirname(__file__), 'local_data') + dataset_uri = f'{local_data}/{dataset_uri}' + tokenizer = tiny_opt_tokenizer + batch_size = 4 + + tmp_path_to_broadcast = str(os.path.abspath(tmp_path)) + gathered_paths = dist.all_gather_object(tmp_path_to_broadcast) + dl = get_icl_task_dataloader( + 'code_evaluation', + dataset_uri=dataset_uri, + tokenizer=tokenizer, + batch_size=batch_size, + max_seq_len=150, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=num_fewshot, + prompt_string='', + example_delimiter='\n', + continuation_delimiter=': ', + destination_path=str(Path(gathered_paths[0]) / 'icl.jsonl'), + generations_per_sample=generations_per_sample, + ) + + evaluator = Evaluator(label='humaneval', + dataloader=dl, + metric_names=['InContextLearningCodeEvalAccuracy'], + device_eval_microbatch_size=1) + model = HuggingFaceModel( + model=tiny_opt_model, + tokenizer=tokenizer, + eval_metrics=[InContextLearningCodeEvalAccuracy()], + use_logits=True, + ) + + trainer = Trainer(model=model, max_duration='1ba', loggers=in_memory_logger) + torch.use_deterministic_algorithms(False) + trainer.eval(eval_dataloader=evaluator) + torch.use_deterministic_algorithms(True) + assert 'metrics/humaneval/InContextLearningCodeEvalAccuracy' in in_memory_logger.data.keys( + ) + assert in_memory_logger.data[ + 'metrics/humaneval/InContextLearningCodeEvalAccuracy'][0][1].item() == 0 + + +@pytest.mark.parametrize('dataset_uri', ['human_eval_small.jsonl']) +@pytest.mark.parametrize('num_fewshot', [0]) +@pytest.mark.parametrize('generations_per_sample', range(1, 3)) +@pytest.mark.gpu +@pytest.mark.world_size(2) +@pytest.mark.filterwarnings( + r'ignore:.*The dataloader_len \(2\) is greater than the length.*:UserWarning' +) +def test_code_eval_sentpiece_evaluation( + monkeypatch: pytest.MonkeyPatch, num_fewshot: int, dataset_uri: str, + tiny_opt_tokenizer: transformers.AutoTokenizer, + tiny_opt_model: transformers.AutoModelForCausalLM, tmp_path: Path, + generations_per_sample: int): + + monkeypatch.setenv('CODE_EVAL_DEVICE', 'LOCAL') + in_memory_logger = InMemoryLogger( + ) # track the logged metrics in the in_memory_logger + local_data = os.path.join(os.path.dirname(__file__), 'local_data') + dataset_uri = f'{local_data}/{dataset_uri}' + tokenizer = tiny_opt_tokenizer + batch_size = 2 + tmp_path_to_broadcast = str(os.path.abspath(tmp_path)) + gathered_paths = dist.all_gather_object(tmp_path_to_broadcast) + dl = get_icl_task_dataloader( + 'code_evaluation', + dataset_uri=dataset_uri, + tokenizer=tokenizer, + batch_size=batch_size, + max_seq_len=175, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=num_fewshot, + prompt_string='', + example_delimiter='\n', + continuation_delimiter=': ', + destination_path=str(Path(gathered_paths[0]) / 'icl.jsonl'), + generations_per_sample=generations_per_sample, + ) + + evaluator = Evaluator(label='humaneval', + dataloader=dl, + metric_names=['InContextLearningCodeEvalAccuracy']) + model = HuggingFaceModel( + model=tiny_opt_model, + tokenizer=tiny_opt_tokenizer, + eval_metrics=[InContextLearningCodeEvalAccuracy()], + use_logits=True, + ) + + trainer = Trainer(model=model, max_duration='1ba', loggers=in_memory_logger) + torch.use_deterministic_algorithms(False) + trainer.eval(eval_dataloader=evaluator) + torch.use_deterministic_algorithms(True) + assert 'metrics/humaneval/InContextLearningCodeEvalAccuracy' in in_memory_logger.data.keys( + ) + assert in_memory_logger.data[ + 'metrics/humaneval/InContextLearningCodeEvalAccuracy'][0][1].item() == 0 + + +@pytest.mark.parametrize('dataset_uri', ['human_eval_small.jsonl']) +@pytest.mark.parametrize('num_fewshot', [0, 2]) +@pytest.mark.parametrize('generations_per_sample', [1]) +@pytest.mark.filterwarnings(r'ignore: Input length of input_ids is') +@pytest.mark.gpu +@pytest.mark.world_size(2) +@pytest.mark.filterwarnings( + r'ignore:.*The dataloader_len \(2\) is greater than the length.*:UserWarning' +) +def test_code_eval_task_evaluation( + monkeypatch: pytest.MonkeyPatch, num_fewshot: int, dataset_uri: str, + tiny_gpt2_tokenizer: transformers.AutoTokenizer, + tiny_gpt2_model: transformers.AutoModelForCausalLM, tmp_path: Path, + generations_per_sample: int): + + monkeypatch.setenv('CODE_EVAL_DEVICE', 'LOCAL') + in_memory_logger = InMemoryLogger( + ) # track the logged metrics in the in_memory_logger + local_data = os.path.join(os.path.dirname(__file__), 'local_data') + dataset_uri = f'{local_data}/{dataset_uri}' + tokenizer = tiny_gpt2_tokenizer + batch_size = 2 + tmp_path_to_broadcast = str(os.path.abspath(tmp_path)) + gathered_paths = dist.all_gather_object(tmp_path_to_broadcast) + dl = get_icl_task_dataloader( + 'code_evaluation', + dataset_uri=dataset_uri, + tokenizer=tokenizer, + batch_size=batch_size, + max_seq_len=64 * num_fewshot, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=num_fewshot, + prompt_string='', + example_delimiter='\n', + continuation_delimiter=': ', + destination_path=str(Path(gathered_paths[0]) / 'icl.jsonl'), + generations_per_sample=generations_per_sample, + ) + + evaluator = Evaluator(label='humaneval', + dataloader=dl, + metric_names=['InContextLearningCodeEvalAccuracy']) + model = HuggingFaceModel( + model=tiny_gpt2_model, + tokenizer=tiny_gpt2_tokenizer, + eval_metrics=[InContextLearningCodeEvalAccuracy()], + use_logits=True, + ) + + trainer = Trainer(model=model, max_duration='1ba', loggers=in_memory_logger) + torch.use_deterministic_algorithms(False) + trainer.eval(eval_dataloader=evaluator) + torch.use_deterministic_algorithms(True) + assert 'metrics/humaneval/InContextLearningCodeEvalAccuracy' in in_memory_logger.data.keys( + ) + assert in_memory_logger.data[ + 'metrics/humaneval/InContextLearningCodeEvalAccuracy'][0][1].item() == 0 + + +@pytest.mark.parametrize('dataset_uri', ['lambada_small.jsonl']) +def test_lm_spacing_dataloader(dataset_uri: str, + tiny_gpt2_tokenizer: transformers.AutoTokenizer, + tmp_path: Path): + + local_data = os.path.join(os.path.dirname(__file__), 'local_data') + + tokenizer = tiny_gpt2_tokenizer + dataset_uri = f'{local_data}/{dataset_uri}' + batch_size = 2 + seqlen = 512 + dl = get_icl_task_dataloader('language_modeling', + dataset_uri=dataset_uri, + tokenizer=tokenizer, + batch_size=batch_size, + max_seq_len=seqlen, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=1, + prompt_string='', + example_delimiter='\n', + continuation_delimiter=' UNIQUE ', + destination_path=str(tmp_path / 'icl.jsonl')) + assert isinstance(dl, DataSpec) + assert isinstance(dl.dataloader, DataLoader) # pyright + first_batch = next(dl.dataloader._get_iterator()) + second_batch = next(dl.dataloader._get_iterator()) + + first_batch_text = tokenizer.decode(first_batch['input_ids'][0], + skip_special_tokens=True) + second_batch_text = tokenizer.decode(second_batch['input_ids'][0], + skip_special_tokens=True) + + first_batch_without_last_word = ' '.join(first_batch_text.split(' ')[:-1]) + second_batch_without_last_word = ' '.join(second_batch_text.split(' ')[:-1]) + + assert first_batch_without_last_word.endswith(' UNIQUE') + assert second_batch_without_last_word.endswith(' UNIQUE') + + assert first_batch_without_last_word.count(' UNIQUE ') == 1 + assert second_batch_without_last_word.count(' UNIQUE ') == 1 + + +@pytest.mark.parametrize('dataset_uri', ['hf://mosaicml/test_dataset']) +@pytest.mark.parametrize('num_fewshot', [0, 1]) +@pytest.mark.parametrize('prompt_string', ['Complete the voiceline: ', '']) +@pytest.mark.parametrize('hf_loading_vars', [{ + 'split': 'test', + 'name': 'juggernaut', +}]) +@pytest.mark.parametrize( + 'hf_parsing_map', + [None, { + 'context': ['context'], + 'continuation': ['continuation'] + }]) +@pytest.mark.filterwarnings( + r'ignore:The repository for mosaicml/test_dataset contains custom code which must*:FutureWarning' +) +def test_hf_dataloading_lm_dataloader( + dataset_uri: str, tiny_gpt2_tokenizer: transformers.AutoTokenizer, + tmp_path: Path, num_fewshot: int, prompt_string: str, + hf_loading_vars: Dict[str, + str], hf_parsing_map: Optional[Dict[str, + List[str]]]): + + tokenizer = tiny_gpt2_tokenizer + batch_size = 2 + seqlen = 2048 + dl = get_icl_task_dataloader( + 'language_modeling', + dataset_uri=dataset_uri, + tokenizer=tokenizer, + batch_size=batch_size, + max_seq_len=seqlen, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=0, + prompt_string='', + example_delimiter='\n', + continuation_delimiter=' ', + destination_path=str(tmp_path / 'test_dataset_lm_juggernaut.jsonl'), + hf_loading_vars=hf_loading_vars, + hf_parsing_map=hf_parsing_map) + assert isinstance(dl, DataSpec) + assert isinstance(dl.dataloader, DataLoader) # pyright + batch = next(dl.dataloader._get_iterator()) + + assert 'input_ids' in batch + assert tuple(batch['input_ids'].shape) == (batch_size, seqlen) + assert 'attention_mask' in batch + assert tuple(batch['attention_mask'].shape) == (batch_size, seqlen) + assert 'continuation_indices' in batch + assert isinstance(batch['continuation_indices'], list) and len( + batch['continuation_indices']) == batch_size + assert 'mode' in batch + assert batch['mode'] == 'icl_task' + min_idx = min(batch['continuation_indices'][0]).item() + max_idx = max(batch['continuation_indices'][0]).item() + assert tokenizer.decode(batch['input_ids'][0][min_idx:max_idx + + 1]) == ' and me.' + + decoded_batch = [ + tokenizer.decode(row[row != tokenizer.eos_token_id]) + for row in batch['input_ids'] + ] + assert decoded_batch[0] == "Looks like it's just you and me." + assert decoded_batch[ + 1] == "There's a fine line between bravery and stupidity." + + +@pytest.mark.parametrize('dataset_uri', ['hf://mosaicml/test_dataset']) +@pytest.mark.parametrize('num_fewshot', [0, 1]) +@pytest.mark.parametrize('prompt_string', ['What spell does this invoke? ', '']) +@pytest.mark.parametrize('hf_loading_vars', [{ + 'split': 'test', + 'name': 'invoker', +}]) +@pytest.mark.parametrize('hf_parsing_map', [{ + 'context': ['quas', 'wex', 'exort'], + 'answer': ['spell'] +}]) +@pytest.mark.filterwarnings( + r'ignore:The repository for mosaicml/test_dataset contains custom code which must*:FutureWarning' +) +def test_hf_dataloading_custom_parsing( + dataset_uri: str, tiny_gpt2_tokenizer: transformers.AutoTokenizer, + tmp_path: Path, num_fewshot: int, prompt_string: str, + hf_loading_vars: Dict[str, str], hf_parsing_map: Dict[str, List[str]]): + + tokenizer = tiny_gpt2_tokenizer + batch_size = 2 + seqlen = 2048 + + # empirical number from the small test dataset + maximum_answer_length = 4 + + dl = get_icl_task_dataloader( + 'generation_task_with_answers', + dataset_uri=dataset_uri, + tokenizer=tokenizer, + batch_size=batch_size, + max_seq_len=seqlen, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=num_fewshot, + prompt_string=prompt_string, + example_delimiter='\n', + question_prelimiter='Orbs: ', + continuation_delimiter='\nSpell:', + destination_path=str(tmp_path / 'test_dataset_lm_juggernaut.jsonl'), + hf_loading_vars=hf_loading_vars, + hf_parsing_map=hf_parsing_map) + assert isinstance(dl, DataSpec) + assert isinstance(dl.dataloader, DataLoader) # pyright + batch = next(dl.dataloader._get_iterator()) + + assert tuple(batch['input_ids'].shape) == (batch_size, + seqlen - maximum_answer_length) + assert tuple(batch['attention_mask'].shape) == (batch_size, seqlen - + maximum_answer_length) + assert batch['mode'] == 'generate' + # the maximum generation length from the small test data + assert batch['generation_kwargs']['max_new_tokens'] == maximum_answer_length + assert all(item[0] == tokenizer.eos_token_id for item in batch['input_ids']) + + decoded_batch = tokenizer.batch_decode(batch['input_ids']) + assert all( + item.count('Orbs: ') == num_fewshot + 1 for item in decoded_batch) + assert all( + item.count('\nSpell:') == num_fewshot + 1 for item in decoded_batch) + + if len(prompt_string) > 0: + assert all( + item.count('What spell does this invoke? ') == 1 + for item in decoded_batch) + assert all( + set(found) == set(expected) for found, expected in zip( + batch['labels'], [['defeaning blast'], ['cold snap']])) + assert decoded_batch[0].endswith('Orbs: quas wex exort\nSpell:') + assert decoded_batch[1].endswith('Orbs: quas quas quas\nSpell:') diff --git a/tests/eval/test_nlp_metrics.py b/tests/eval/test_nlp_metrics.py new file mode 100644 index 0000000000..344d642715 --- /dev/null +++ b/tests/eval/test_nlp_metrics.py @@ -0,0 +1,196 @@ +# Copyright 2024 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, List + +import pytest +import torch +import transformers + +from llmfoundry.eval.metrics import ( + InContextLearningCodeEvalAccuracy, + InContextLearningGenerationExactMatchAccuracy, InContextLearningLMAccuracy, + InContextLearningMultipleChoiceAccuracy) + + +def test_in_context_learning_lm_accuracy( + tiny_gpt2_tokenizer: transformers.AutoTokenizer): + contexts = ['The dog is', 'I love to eat', 'I hate', 'The weather is'] + continuations = [' furry', ' pie', ' long lines', ' snowy'] + pad = tiny_gpt2_tokenizer.pad_token_id + inputs = [ + tiny_gpt2_tokenizer(context)['input_ids'] + + tiny_gpt2_tokenizer(continuation)['input_ids'] + for context, continuation in zip(contexts, continuations) + ] + inputs = torch.tensor( + [input + [pad] * (2048 - len(input)) for input in inputs]) + + cont_idxs = [] + for context, continuation in zip(contexts, continuations): + start = len(tiny_gpt2_tokenizer(context)['input_ids']) + end = start + len(tiny_gpt2_tokenizer(continuation)['input_ids']) + cont_idxs.append(torch.tensor(list(range(start, end)))) + + batch = { + 'continuation_indices': cont_idxs, + 'labels': inputs.roll(-1), + 'input_ids': inputs + } + logits = torch.nn.functional.one_hot(inputs.roll(-1), + num_classes=pad + 1).float() * 100 + start, end = cont_idxs[1].tolist()[0] - 1, cont_idxs[1].tolist()[-1] + logits[1][start:end] = logits[0][start:end].clone( + ) # make one of the answer's continuations incorrect + metric = InContextLearningLMAccuracy() + metric.update(batch, logits, batch['labels']) + + assert metric.compute() == 0.75 + + +def test_in_context_learning_qa_accuracy(): + outputs = [ + 'Correct but then some more text', 'Incorrect', + ' the CORREct with weird casing and spacing' + ] + labels = [['Correct'], ['blah', 'blah2'], ['blah', 'correct']] + batch = {'cot_delimiter': '', 'labels': labels} + metric = InContextLearningGenerationExactMatchAccuracy() + metric.update(batch, outputs, labels) + + assert metric.compute() == (2 / 3) + + +def test_in_context_learning_qa_cot_accuracy(): + outputs = [ + 'chain of thought ### Correct but then some more text\n\nanother chain of thought ### Incorrect answer this time', + 'Incorrect', + 'chain of thought ### the CORREct with weird casing and spacing', + 'incorrect chain of thought delimiter ## Correct but wrong delimiter' + ] + labels = [['Correct'], ['blah', 'blah2'], ['blah', 'correct'], ['correct']] + batch = { + 'cot_delimiter': ' ### ', + 'labels': labels, + 'do_normalization': True, + 'stopping_criteria': '\n\n' + } + metric = InContextLearningGenerationExactMatchAccuracy() + metric.update(batch, outputs, labels) + + assert metric.compute() == (2 / 4) + + +def test_in_context_learning_code_eval_accuracy( + monkeypatch: pytest.MonkeyPatch): + outputs = [ + ' return 1 if n <= 1 else fib(n - 1) + fib(n - 1)', # incorrect + ' if n <= 1:\n return 1\n return fib(n-1) + fib(n-2)', # incorrect spacing + ' return n * 2', # correct + ' return 2*n', # correct + ' return n + 2', # incorrect + ' return n + 1' + ] # correct + labels = [] + prompts = [ + 'def fib(n):\n', 'def multiply_by_two(n):\n', 'def add_one(n):\n' + ] + entry_points = ['fib', 'multiply_by_two', 'add_one'] + test_inputs = [['(1,)', '(2,)', '(4,)'], ['(1,)', '(2,)', '(4,)'], + ['(1,)', '(2,)', '(4,)']] + test_outputs = [['1', '2', '5'], ['2', '4', '8'], ['2', '3', '5']] + sample_ids = [0, 1, 2] + languages = ['python', 'python', 'python'] + monkeypatch.setenv('CODE_EVAL_DEVICE', 'LOCAL') + generations_per_sample = 2 + + def repeat(values: List[Any]): + return [val for val in values for _ in range(generations_per_sample)] + + transformers = pytest.importorskip('transformers') + tokenizer = transformers.AutoTokenizer.from_pretrained( + 'mosaicml/mpt-7b') # type: ignore reportUnboundVariable + tokenizer.pad_token = tokenizer.eos_token + input_ids = tokenizer.batch_encode_plus(repeat(prompts), + return_tensors='pt', + padding=True)['input_ids'] + batch = { + # This tests deterministic beam search rather than sampling + 'input_ids': input_ids, + 'generation_kwargs': { + 'num_beams': 1, + }, + 'prompts': repeat(prompts), + 'pass_at_k': [1], + 'entry_points': repeat(entry_points), + 'test_inputs': repeat(test_inputs), + 'test_outputs': repeat(test_outputs), + 'languages': repeat(languages), + 'dataset_size': len(prompts), + 'generations_per_sample': generations_per_sample, + 'sample_id': repeat(sample_ids), + } + metric = InContextLearningCodeEvalAccuracy() + metric.update(batch, outputs, labels) + + # pass@1 values + # program 1: 0 + # program 2: 1 + # program 3: .5 + # mean: 0.5 + assert metric.compute() == 0.5 + + +def test_in_context_learning_mc_accuracy( + tiny_gpt2_tokenizer: transformers.AutoTokenizer): + contexts = [ + 'Q: How do you cook a cake?', 'Q: How do you cook a cake?', + 'Q: How old is the earth?', 'Q: How old is the earth?' + ] + continuations = [ + ' A: turn on the oven', ' A: do a backflip', ' A: 2 minutes', + ' A: 4.5 billion years' + ] + gold_indices = [0, 1] + choice_groupings = [(0, 2), (2, 4)] + pad = tiny_gpt2_tokenizer.pad_token_id + inputs = [ + tiny_gpt2_tokenizer(context)['input_ids'] + + tiny_gpt2_tokenizer(continuation)['input_ids'] + for context, continuation in zip(contexts, continuations) + ] + inputs = torch.tensor( + [input + [pad] * (2048 - len(input)) for input in inputs]) + attention_mask = ~(inputs == pad) + + cont_idxs = [] + for context, continuation in zip(contexts, continuations): + start = len(tiny_gpt2_tokenizer(context)['input_ids']) + end = start + len(tiny_gpt2_tokenizer(continuation)['input_ids']) + cont_idxs.append(torch.tensor(list(range(start, end)))) + + batch = { + 'continuation_indices': cont_idxs, + 'labels': inputs.roll(-1), + 'input_ids': inputs, + 'attention_mask': attention_mask, + 'gold_indices': gold_indices, + 'choice_groupings': choice_groupings + } + logits = torch.nn.functional.one_hot(inputs.roll(-1), + num_classes=pad + 1).float() + + # for the first two, the correct answer is continuation 0 + # make the answer correct by making continuation 0 more likely for both answers + start, end = cont_idxs[1].tolist()[0] - 1, cont_idxs[1].tolist()[-1] + logits[1][start:end] = logits[0][start:end].clone() + + # for the last two, the correct answer is continuation 3 + # make the answer incorrect by making continuation 2 more likely for both answers + start, end = cont_idxs[3].tolist()[0], cont_idxs[3].tolist()[-1] + logits[3][start:end] = logits[2][start:end].clone() + + metric = InContextLearningMultipleChoiceAccuracy() + + metric.update(batch, logits, batch['labels']) + assert metric.compute() == 0.5 diff --git a/tests/fixtures/autouse.py b/tests/fixtures/autouse.py index ccbe1b69f7..16e3f8ad6f 100644 --- a/tests/fixtures/autouse.py +++ b/tests/fixtures/autouse.py @@ -9,11 +9,19 @@ import torch from composer.utils import dist, get_device, reproducibility +from llmfoundry.utils.registry_utils import save_registry + # Add llm-foundry repo root to path so we can import scripts in the tests REPO_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')) sys.path.append(REPO_DIR) +@pytest.fixture(autouse=True) +def save_registry_fixture(): + with save_registry(): + yield + + @pytest.fixture(autouse=True) def initialize_dist(request: pytest.FixtureRequest): """Initialize the default PyTorch distributed process group for tests.""" diff --git a/tests/fixtures/models.py b/tests/fixtures/models.py index e4e6892fe3..616d66085c 100644 --- a/tests/fixtures/models.py +++ b/tests/fixtures/models.py @@ -1,8 +1,10 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 +import copy from typing import Any, Callable +import pytest from omegaconf import DictConfig from pytest import fixture from transformers import PreTrainedTokenizerBase @@ -71,3 +73,125 @@ def build(**kwargs: Any) -> ComposerHFCausalLM: return model return build + + +def tiny_gpt2_model_helper(config): # type: ignore + transformers = pytest.importorskip('transformers') + + return transformers.AutoModelForCausalLM.from_config(config) + + +@pytest.fixture(scope='session') +def _session_tiny_gpt2_model(_session_tiny_gpt2_config): # type: ignore + return tiny_gpt2_model_helper(_session_tiny_gpt2_config) + + +def tiny_gpt2_config_helper(): + transformers = pytest.importorskip('transformers') + + tiny_overrides = { + 'n_embd': 2, + 'n_head': 2, + 'n_layer': 2, + 'vocab_size': 50258 # 50257 + 1 for pad token + } + return transformers.AutoConfig.from_pretrained('gpt2', **tiny_overrides) + + +@pytest.fixture(scope='session') +def _session_tiny_gpt2_config(): # type: ignore + return tiny_gpt2_config_helper() + + +def tiny_gpt2_tokenizer_helper(): + transformers = pytest.importorskip('transformers') + + hf_tokenizer = transformers.AutoTokenizer.from_pretrained('gpt2') + hf_tokenizer.add_special_tokens({'pad_token': '[PAD]'}) + return hf_tokenizer + + +@pytest.fixture +def tiny_gpt2_model(_session_tiny_gpt2_model): # type: ignore + return copy.deepcopy(_session_tiny_gpt2_model) + + +@pytest.fixture(scope='session') +def _session_tiny_gpt2_tokenizer(): # type: ignore + return tiny_gpt2_tokenizer_helper() + + +@pytest.fixture +def tiny_gpt2_tokenizer(_session_tiny_gpt2_tokenizer): # type: ignore + return copy.deepcopy(_session_tiny_gpt2_tokenizer) + + +def tiny_llama_tokenizer_helper(): + transformers = pytest.importorskip('transformers') + + hf_tokenizer = transformers.AutoTokenizer.from_pretrained( + 'huggyllama/llama-7b', use_fast=False) + return hf_tokenizer + + +@pytest.fixture(scope='session') +def _session_tiny_llama_tokenizer(): # type: ignore + return tiny_llama_tokenizer_helper() + + +@pytest.fixture +def tiny_llama_tokenizer(_session_tiny_llama_tokenizer): # type: ignore + return copy.deepcopy(_session_tiny_llama_tokenizer) + + +def tiny_opt_tokenizer_helper(): + transformers = pytest.importorskip('transformers') + + hf_tokenizer = transformers.AutoTokenizer.from_pretrained( + 'facebook/opt-125m') + hf_tokenizer.add_special_tokens({'pad_token': '[PAD]'}) + return hf_tokenizer + + +def tiny_opt_model_helper(config): # type: ignore + transformers = pytest.importorskip('transformers') + + return transformers.AutoModelForCausalLM.from_config(config) + + +@pytest.fixture(scope='session') +def _session_tiny_opt_tokenizer(): # type: ignore + return tiny_opt_tokenizer_helper() + + +@pytest.fixture(scope='session') +def _session_tiny_opt_config(): # type: ignore + return tiny_opt_config_helper() + + +@pytest.fixture(scope='session') +def _session_tiny_opt_model(_session_tiny_opt_config): # type: ignore + return tiny_opt_model_helper(_session_tiny_opt_config) + + +def tiny_opt_config_helper(): + transformers = pytest.importorskip('transformers') + + tiny_overrides = { + 'n_embd': 2, + 'n_head': 2, + 'n_layer': 2, + 'vocab_size': 50272 + } + return transformers.AutoConfig.from_pretrained('facebook/opt-125m', + **tiny_overrides) + + +@pytest.fixture +def tiny_opt_tokenizer(_session_tiny_opt_tokenizer): # type: ignore + return copy.deepcopy(_session_tiny_opt_tokenizer) + + +@pytest.fixture +def tiny_opt_model(_session_tiny_opt_model): # type: ignore + return copy.deepcopy(_session_tiny_opt_model) diff --git a/tests/models/layers/test_dmoe.py b/tests/models/layers/test_dmoe.py index 9c15745793..c8e7ec3e67 100644 --- a/tests/models/layers/test_dmoe.py +++ b/tests/models/layers/test_dmoe.py @@ -239,6 +239,10 @@ def test_fwd_equal_dmoe(seqlen: int, precision: str, mlp_type: str): torch_dmoe_config = copy.deepcopy(mb_dmoe_config) torch_dmoe_config.ffn_config['ffn_type'] = 'torch_dmoe' + del torch_dmoe_config.ffn_config['moe_world_size'] + del torch_dmoe_config.ffn_config['fc_type'] + del torch_dmoe_config.ffn_config['moe_loss_weight'] + del torch_dmoe_config.ffn_config['return_bias'] mb_dmoe_model = MPTForCausalLM(mb_dmoe_config).to(device=device, dtype=dtype) diff --git a/tests/models/layers/test_flash_torch.py b/tests/models/layers/test_flash_torch.py index c0e9f4b3b5..f212665c93 100644 --- a/tests/models/layers/test_flash_torch.py +++ b/tests/models/layers/test_flash_torch.py @@ -8,6 +8,7 @@ from llmfoundry.models.layers import attention from llmfoundry.models.layers.attention import (check_alibi_support, gen_slopes, is_flash_v2_installed) +from llmfoundry.models.layers.layer_builders import build_attention_layer from llmfoundry.models.mpt.modeling_mpt import (apply_sequence_id, gen_attention_mask_in_length, gen_flash_attn_padding_info, @@ -120,9 +121,15 @@ def test_attn_impl(attn_impl_0: str, ]).to(device=device) cfg.attn_impl = attn_impl_0 - attn0 = attention.ATTN_CLASS_REGISTRY[attn_type](**cfg).to(device) + attn0 = build_attention_layer( + name=attn_type, + attn_kwargs=om.to_container(cfg), # type: ignore + ).to(device) cfg.attn_impl = attn_impl_1 - attn1 = attention.ATTN_CLASS_REGISTRY[attn_type](**cfg).to(device) + attn1 = build_attention_layer( + name=attn_type, + attn_kwargs=om.to_container(cfg), # type: ignore + ).to(device) attn1.load_state_dict(attn0.state_dict()) diff --git a/tests/models/test_rope_dail_vs_hf.py b/tests/models/test_rope_dail_vs_hf.py index 33c3d3c052..b9ab90357a 100644 --- a/tests/models/test_rope_dail_vs_hf.py +++ b/tests/models/test_rope_dail_vs_hf.py @@ -7,6 +7,7 @@ from omegaconf import OmegaConf as om from llmfoundry.models.layers.attention import is_flash_v2_installed +from llmfoundry.models.layers.layer_builders import build_attention_layer from llmfoundry.models.mpt.modeling_mpt import (gen_flash_attn_padding_info, gen_rotary_embedding) @@ -21,8 +22,6 @@ def test_rope_dail_vs_hf(attn_type: str, seq_len: int, device: str = 'cuda'): if not is_flash_v2_installed(): pytest.skip('dail implementation of rope requires flash attention 2.') - from llmfoundry.models.layers import attention - cfg = om.create({ 'attn_impl': 'flash', 'd_model': 128, @@ -37,8 +36,16 @@ def test_rope_dail_vs_hf(attn_type: str, seq_len: int, device: str = 'cuda'): if attn_type == 'grouped_query_attention': cfg.kv_n_heads = 2 - attn0 = attention.ATTN_CLASS_REGISTRY[attn_type](**cfg).to(device) - attn1 = attention.ATTN_CLASS_REGISTRY[attn_type](**cfg).to(device) + attn0 = build_attention_layer( + name=attn_type, + attn_kwargs=om.to_container( + cfg), # type: ignore (to_container return broad type) + ).to(device) + attn1 = build_attention_layer( + name=attn_type, + attn_kwargs=om.to_container( + cfg), # type: ignore (to_container return broad type) + ).to(device) attn1.load_state_dict(attn0.state_dict()) x0 = torch.randn(batch_size, seq_len, cfg.d_model).to(device) diff --git a/tests/models/utils/test_param_init_fns.py b/tests/models/utils/test_param_init_fns.py index 6be2c5ca42..0efc245602 100644 --- a/tests/models/utils/test_param_init_fns.py +++ b/tests/models/utils/test_param_init_fns.py @@ -12,7 +12,8 @@ from omegaconf import OmegaConf as om from torch import nn -from llmfoundry.models.utils import MODEL_INIT_REGISTRY, generic_param_init_fn_ +from llmfoundry.layers_registry import param_init_fns +from llmfoundry.models.utils import generic_param_init_fn_ class MLP(nn.Module): @@ -150,7 +151,7 @@ def test_emb_init(emb_init_cfg: Optional[Tuple[str, Union[int, List[int]]]]): bias=True)), ])) - model.apply(partial(MODEL_INIT_REGISTRY['kaiming_normal_'], **dict_cfg)) + model.apply(partial(param_init_fns.get('kaiming_normal_'), **dict_cfg)) assert isinstance(model.emb, torch.nn.Embedding) diff --git a/tests/test_registry.py b/tests/test_registry.py index c93c7c9749..d7a1fc7dfe 100644 --- a/tests/test_registry.py +++ b/tests/test_registry.py @@ -31,6 +31,14 @@ def test_expected_registries_exist(): 'metrics', 'models', 'norms', + 'param_init_fns', + 'module_init_fns', + 'ffns', + 'ffns_with_norm', + 'ffns_with_megablocks', + 'attention_classes', + 'attention_implementations', + 'fcs', } assert existing_registries == expected_registry_names