diff --git a/composer/datasets/__init__.py b/composer/datasets/__init__.py index 56672e27f3..a456ec3239 100644 --- a/composer/datasets/__init__.py +++ b/composer/datasets/__init__.py @@ -11,6 +11,11 @@ build_streaming_cifar10_dataloader, build_synthetic_cifar10_dataloader) from composer.datasets.imagenet import (build_ffcv_imagenet_dataloader, build_imagenet_dataloader, build_streaming_imagenet1k_dataloader, build_synthetic_imagenet_dataloader) +from composer.datasets.in_context_learning_evaluation import (InContextLearningCodeEvalDataset, + InContextLearningDataset, InContextLearningLMTaskDataset, + InContextLearningMultipleChoiceTaskDataset, + InContextLearningQATaskDataset, + InContextLearningSchemaTaskDataset) from composer.datasets.lm_dataset import build_lm_dataloader from composer.datasets.mnist import build_mnist_dataloader, build_synthetic_mnist_dataloader from composer.datasets.synthetic import (SyntheticBatchPairDataset, SyntheticDataLabelType, SyntheticDataType, @@ -24,6 +29,12 @@ 'SyntheticDataLabelType', 'SyntheticDataType', 'SyntheticPILDataset', + 'InContextLearningDataset', + 'InContextLearningQATaskDataset', + 'InContextLearningLMTaskDataset', + 'InContextLearningCodeEvalDataset', + 'InContextLearningMultipleChoiceTaskDataset', + 'InContextLearningSchemaTaskDataset', 'build_ade20k_dataloader', 'build_streaming_ade20k_dataloader', 'build_streaming_c4_dataloader', diff --git a/composer/datasets/in_context_learning_evaluation.py b/composer/datasets/in_context_learning_evaluation.py index 294bb1b2ba..4e0e30f1ff 100644 --- a/composer/datasets/in_context_learning_evaluation.py +++ b/composer/datasets/in_context_learning_evaluation.py @@ -4,15 +4,14 @@ from __future__ import annotations +import copy import json import os import random -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Union import torch -import transformers from torch.utils.data import DataLoader, Dataset -from tqdm import tqdm from composer.core import DataSpec from composer.core.data_spec import _default_split_batch, _split_list @@ -21,6 +20,7 @@ if TYPE_CHECKING: import transformers + from datasets import Dataset as HFDataset # pyright: ignore[reportGeneralTypeIssues] # Allow models to have slightly more tokens than were used in the most verbose CoT in the dataset _MAX_ANSWER_BUFFER_LENGTH = 10 @@ -34,35 +34,106 @@ ] -def strip_data(samples): - return [{k: v.strip() if isinstance(v, str) else v for k, v in entry.items()} for entry in samples] +def strip_data(example: Dict) -> Dict: + """ + Remove white space from the begging and end of string values in a dictionary + + Args: + example: Dictionary to be stripped + + Returns: + dict: The same dictionary with .strip() applied to any value in the dict that is a string + """ + return {k: v.strip() if isinstance(v, str) else v for k, v in example.items()} + + +def _tokenizer_needs_prefix_space(tokenizer: transformers.PreTrainedTokenizerBase) -> bool: + """ + Test for whether a prefix space is needed before the continuation. + Sentencepiece tokenization should not have a prefix space, but gpt2 style BPE should. 
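# --- Illustrative sketch (editor's aside, not part of this diff) ---
# The check documented above hinges on how a tokenizer splits the string ' a'. For a GPT-2
# style BPE tokenizer, ' a' is a single token, so a prefix space must be added before
# continuations; SentencePiece-style tokenizers typically split it into more than one token.
# The model name below is an assumption used only for demonstration.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained('gpt2')
needs_prefix_space = len(tok(' a', add_special_tokens=False)['input_ids']) == 1
print(needs_prefix_space)  # expected True for gpt2-style BPE tokenizers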
+ + Args: + tokenizer: Tokenizer to test + + Returns: + bool: Whether or not the tokenizer needs a prefix space + """ + test_tokens = tokenizer(' a', add_special_tokens=False)['input_ids'] + assert isinstance(test_tokens, list) + return len(test_tokens) == 1 -def _tokenizer_needs_prefix_space(tokenizer) -> bool: - # Test for whether a prefix space is needed before the continuation. - # sentencepiece tokenization should not have a prefix space, but gpt2 style BPE should - return len(tokenizer(' a', add_special_tokens=False)['input_ids']) == 1 +def _trim_context(context_enc: List, continuation_enc: List, max_seq_len: int) -> List: + """ + Trims a list of tokens down to `max_seq_len` if the length of the list plus the continuation + is more than `max_seq_len`. It will always trim tokens from the left, i.e. tokens at the beginning + of the context will be removed. + Args: + context_enc (list): List of tokens in the context + continuation_enc (lsit): List of tokens in the continuation + max_seq_len (int): Maximum length the model can ingest -def _make_padded_input(context_enc, continuation_enc, max_seq_len, pad_tok_id, padding_side='right'): + Returns: + list: The encoded context trimmed from the left + """ if len(continuation_enc) + len(context_enc) > max_seq_len: - # clip from the end context_max_subseq_len = max_seq_len - len(continuation_enc) if context_max_subseq_len < 0: - raise Exception(f'Dataset included continuation longer than the max seq len') # can't support continuations which are longer than the max seq len + raise Exception(f'Dataset included continuation longer than the max seq len') + # clip from the end context_enc = context_enc[-(context_max_subseq_len):] + return context_enc + + +def _get_continuation_span(context_enc: List, continuation_enc: List) -> torch.Tensor: + """ + Gets the list of indices of the continuation tokens for language modeling or generation tasks. + + Args: + context_enc (list): List of context tokens + continuation_enc (list): List of continuation tokens + + Returns: + torch.tensor: A tensor containing indices corresponding to continuation tokens + """ + return torch.tensor(range(len(context_enc), len(context_enc) + len(continuation_enc))) + + +def _make_padded_input(context_enc: List, + continuation_enc: List, + max_seq_len: int, + pad_tok_id: int, + padding_side: str = 'right') -> torch.Tensor: + """ + Takes an encoded context and continuation and clips the beginning of the context if they're too long. + Adds the padding token to the specified side. + + Args: + context_enc (List): The encoded input to the model + continuation_enc (List): The encoded desired output for the example + max_seq_list (int): Maximum length sequences can be + pad_tok_id (int): The token id we pad with + padding_side (str): Which side to pad the context on. Can be 'right' or 'left + + Returns: + input (torch.tensor): The padded and encoded context + continuation_span (torch.tensor): The _inclusive_ range of indices corresponding to the continuation + """ - # continuation span is the _inclusive_ range of indices corresponding to the continuation - continuation_span = torch.tensor(range(len(context_enc), len(context_enc) + len(continuation_enc))) inp = torch.tensor( (context_enc + continuation_enc), dtype=torch.long, ) (inp_len,) = inp.shape + # Sometimes tokenizers that have neither a pad_tok_id or eos_tok_id will pass None in as the padding + # token and cause errors + if not isinstance(pad_tok_id, int): + raise ValueError(f'`pad_tok_id` must be an integer. 
Found {type(pad_tok_id)} instead') # pad length from seq to padding_length if padding_side == 'right': inp = torch.cat( @@ -83,1109 +154,1207 @@ def _make_padded_input(context_enc, continuation_enc, max_seq_len, pad_tok_id, p else: raise ValueError(f"Unknown padding_side {padding_side}. padding_side must be either 'left' or 'right'") - return inp, continuation_span + return inp + + +def convert_tokens_to_tensors(batch: Dict, tokenize_labels: bool) -> Dict[str, Any]: + """ + HF Datasets converts tensors into lists when we store them, and we don't want to use `type='torch'` + because some content in the dataset, like generation args or single ints, should not be converted. + + Here, we convert those lists of tokens back into tensors in order to feed them into the model. + + Args: + batch (dict): A dictionary of batched inputs + tokenize_labels (bool): Whether or not the labels are tokenized (and need to be stacked) + Returns: + dict: The batch with torch tensors in the corresponding keys instead of lists of lists + """ + batch['input_ids'] = torch.stack(list(map(torch.tensor, batch['input_ids']))) + if tokenize_labels: + batch['labels'] = torch.stack(list(map(torch.tensor, batch['labels']))) + batch['continuation_indices'] = list(map(torch.tensor, batch['continuation_indices'])) + return batch -def _get_fewshot_sample_idxs(dataset_size: int, num_fewshot: int, sample_idx: int, rng: random.Random): - # samples without replacement. if num_fewshot exceeds the number of unique samples, - # then we will have fewer than num_fewshot examples in context + +def _get_fewshot_sample_idxs(dataset_size: int, num_fewshot: int, example_idx: int, rng: random.Random) -> Set[int]: + """ + Samples indices without replacement. If num_fewshot exceeds the number of unique examples in the dataset, + then we will have fewer than num_fewshot examples in context. + Args: + dataset_size (int): Length of the dataset + num_fewshot (int): Number of examples to prepend + example_idx (int): Current example's index (excluded from fewshot choices) + rng (random.Random): RNG for repeatable sample selection + + Returns: + list: Indices of the examples chosen for fewshot selection + """ num_fewshot = min(dataset_size - 1, num_fewshot) fewshot_idxs = set(rng.sample(range(0, dataset_size), num_fewshot)) - if sample_idx in fewshot_idxs: - fewshot_idxs.remove(sample_idx) + if example_idx in fewshot_idxs: + fewshot_idxs.remove(example_idx) if len(fewshot_idxs) >= dataset_size - 1: return fewshot_idxs replacement_sample = rng.choice(range(0, dataset_size)) - while replacement_sample in fewshot_idxs or replacement_sample == sample_idx: + while replacement_sample in fewshot_idxs or replacement_sample == example_idx: replacement_sample = rng.choice(range(0, dataset_size)) fewshot_idxs.add(replacement_sample) return fewshot_idxs -class InContextLearningQATaskDataset(Dataset): - """A dataset that construct batches for in-context learning question answering evaluation - - The input format is expected to be a jsonl file with the following fields: - - context: the question - - answer: the preferred answer to the question - - aliases: a list of aliases for the answer +class InContextLearningDataset(Dataset): + """ + A base dataset that constructs batches for in-context learning task evaluations. + The dataset format is expected to be a local jsonl file, a cloud link to a jsonl file, or a Hugging Face dataset link. + 'context' refers to the input a model will recieve before generating an output. 
For example, the question in question answering tasks, + the preceding text in a language modeling task, or the document and question regarding the document in a document understanding task. + 'example' refers to a loaded dictionary, generally containing a context, an answer, and any other information needed to run the task. + 'answer' refers to the desired output of the model. + + When creating a new ICL Dataset, it is likely that you will need to reimplement the following methods: + + - construct_context(): Takes a single example dictionary and formulates the context as a string for that eval question. + - get_answer_from_example(): Takes a single example dictionary and formulates the correct, ground truth answer as a string. + - tokenize_example(): Tokenizes the example and adds any extra content from the original dictionary that needs to be passed downstream. + - read_dataset(): Loads the dataset and does basic parsing. If additional parsing must be done, this is a good place to do so (See InContextLearningQATaskDataset.read_dataset()) + + Additionally, base_batch and batch_mapping must be defined. + + - base_batch (Dict): The base dictionary that the dataset will use to construct a batch. This should contain static values, like generation_kwargs or mode, + and empty lists for values that will need to be accumulated from each example. + NOTE: Sometimes you will need to set base_batch directly after the init call, e.g. in order to use class variables + like self.pad_tok_id or self.max_answer_length. If you manually set generation_kwargs this way, you'll need to call self.update_generation_kwargs() + after setting self.base_batch. + - batch_mapping (Dict): A mapping with keys that are keys in the batch and values that are columns in the loaded dataset. + collate_fn will use this mapping to create batches from self.dataset. Args: - dataset_uri (str): Either a local path, or a remote path beginning with ``s3://``, or another backend - supported by :meth:`composer.utils.maybe_create_object_store_from_uri`. Dataset must consist of rows of JSON data points with "context", - "answer", and "aliases". See tests/datasets/local_data/triviaqa_small.jsonl. - tokenizer (Union[transformers.PreTrainedTokenizer, transformers.PreTrainedTokenizerFast]): The tokenizer used to map between strings and token ids - batch_size (int): Size of a batch used for eval - max_seq_len (int): The maximum sequence length supported by the model - pad_tok_id (int): The special token reserved for padding batches - num_fewshot (int): The number of complete fewshot examples to prepend before each test example - prompt_string (str): Prompt string to put once before all fewshot examples/test examples (e.g. 'translate english to french') - example_delimiter (str): Separator that goes between individual (context, answer) pairs (e.g. '\n') - continuation_delimiter: (str): Separator that goes between context and answer in each example (e.g. '\nA: ') - destination_path (str): Temporary path to store downloaded datasets - question_prelimiter (str): String to put before each question (e.g. 'Q: ') - fewshot_random_seed (int): Random seed to use for fewshot sampling - """ + dataset_uri (str): A local path, a remote path beginning with ``s3://`` or another backend, or a HuggingFace dataset uri prepended with ``hf://``. + Alternate backends must be supported by :meth:`composer.utils.maybe_create_object_store_from_uri`. + A local dataset must consist of rows of JSON data points with task dependent fields. 
+ The default keys expected are "context" and "answer". + tokenizer (transformers.PreTrainedTokenizerBase): The tokenizer used to map between strings and token ids. + max_seq_len (int): The maximum sequence length supported by the model. + pad_tok_id (int): The special token used for padding batches. + num_fewshot (int): The number of complete fewshot examples to prepend before each test example. These are not identical across examples. + fewshot_random_seed (int): Random seed to use for fewshot sampling. + prompt_string (str): Prompt string to put once before all fewshot examples/test examples (e.g. 'Translate english to french.'). + example_delimiter (str): Separator inserted before (context, answer) pairs (e.g. '\\n') for fewshot sampling and prompting. + continuation_delimiter: (str): Separator inserted between context and answer in each example (e.g. '\\nA: '). + destination_path (str): Temporary path to store downloaded datasets. + prelimiter (str): Text to be prepended before each context, including few shot examples (e.g. "Question: "). + context_key (str): The key in the loaded dataset that contains the context. + answer_key (str): The key in the loaded dataset that contains the answer. + strip_dataset (bool): Boolean for whether to strip whitespace from data. Trailing whitespace can cause degenerative outputs, + so unless whitespace should be preserved (for example in code), this should be set to True. + padding_side (str): Side of the content and answer on which to apply padding. Can be either 'right' or 'left'. + padding_size (int): The final size of the tensor after padding. Defaults to max_sequence_length. + base_batch (Dict): The base dictionary upon which a batch is created. See above for more details. + base_mapping (Dict): A mapping of batch keys to dataset columns, used to create batches. See above for more details. + hf_loading_vars (Dict): A dictionary containing keyword arguments to be passed into `load_dataset` if dataset is being pulled from HF. + hf_parsing_map (Dict): A dictionary containing a mapping from HF columns to ICL dataset keys. The dictionary should be formatted {icl_key:[hf_key1, hf_key1]}. + Column contents will be concatenated with ' ' seperating them. If not included, will load the columns already present in the HF dataset. + tokenize_labels (bool): Whether or not the labels should be tokenized. Generally determined by which metric a dataset uses. + generation_kwargs (Dict): A dictionary containing keyword arguments to be passed along to the model's generate function. 
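# --- Illustrative sketch (editor's aside, not part of this diff) ---
# How the HF loading arguments described above might be wired together. The subset, split,
# and column names below are hypothetical.
hf_loading_vars = {'name': 'some_subset', 'split': 'test'}
hf_parsing_map = {
    'context': ['passage', 'question'],  # columns joined with ' ' into the ICL 'context' key
    'answer': ['answer'],
}
# A uri such as 'hf://org/dataset-name' routes read_dataset() through load_dataset(...)
# with hf_loading_vars, then remaps columns according to hf_parsing_map.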
- def _read_dataset(self, dataset: Dataset) -> List[Dict[str, str]]: - result = [] - for example in dataset: - result.append({ - 'context': example['context'], - 'answer': example['answer'], - 'aliases': set([example['answer']] + example.get('aliases', [])), - 'chain_of_thought': example.get('chain_of_thought', '') - }) - return result + """ - def __init__(self, - dataset_uri: str, - tokenizer: Union[transformers.PreTrainedTokenizer, transformers.PreTrainedTokenizerFast], - max_seq_len: int, - pad_tok_id: int, - num_fewshot: int, - prompt_string: str, - example_delimiter: str, - continuation_delimiter: str, - destination_path: str, - question_prelimiter: str, - fewshot_random_seed: int, - cot_delimiter: str = '', - early_stopping_criteria: Optional[List[str]] = None, - do_normalization: bool = True): - if tokenizer.eos_token_id is None: - raise ValueError('`InContextLearningQATaskDataset` tokenizer must have non-null `eos_token_id`') + def __init__( + self, + dataset_uri: str, + tokenizer: transformers.PreTrainedTokenizerBase, + max_seq_len: int, + pad_tok_id: int, + num_fewshot: int, + fewshot_random_seed: int, + prompt_string: str, + example_delimiter: str, + continuation_delimiter: str, + destination_path: str, + prelimiter: str = '', + context_key: str = 'context', + answer_key: str = 'answer', + strip_dataset: bool = True, + padding_side: str = 'right', + tokenize_labels: bool = True, + static_keys: Optional[List] = None, + list_keys: Optional[List] = None, + tensor_keys: Optional[List] = None, + padding_size: Optional[int] = None, + base_batch: Optional[Dict] = None, + batch_mapping: Optional[Dict] = None, + hf_loading_vars: Optional[Dict] = None, + hf_parsing_map: Optional[Dict] = None, + generation_kwargs: Optional[Dict] = None, + ): try: - from datasets import load_dataset # pyright: ignore [reportGeneralTypeIssues] + import datasets + del datasets except ImportError as e: - raise MissingConditionalImportError(extra_deps_group='nlp', - conda_package='datasets', - conda_channel='conda-forge') from e - with dist.local_rank_zero_download_and_wait(destination_path): - if dist.get_local_rank() == 0: - get_file(dataset_uri, destination_path, overwrite=True) - dataset = load_dataset('json', data_files=destination_path, split='train', streaming=False) - self.early_stopping_criteria = early_stopping_criteria - self.do_normalization = do_normalization - self.samples = self._read_dataset(dataset) # pyright: ignore[reportGeneralTypeIssues] - self.samples = strip_data(self.samples) + raise MissingConditionalImportError( + extra_deps_group='nlp', + conda_package='datasets', + conda_channel='conda-forge', + ) from e + self.tokenizer = tokenizer + self.prefix_space = _tokenizer_needs_prefix_space(self.tokenizer) + self.max_seq_len = max_seq_len self.pad_tok_id = pad_tok_id - self.padding_side = 'left' - self.max_answer_length = 0 + self.num_fewshot = num_fewshot + self.padding_side = padding_side + self.padding_size = padding_size if padding_size else self.max_seq_len + self.prelimiter = prelimiter + self.example_delimiter = example_delimiter + self.continuation_delimiter = continuation_delimiter + self.context_key = context_key + self.answer_key = answer_key + self.tokenize_labels = tokenize_labels + self.batch_mapping = batch_mapping or {} + self.base_batch = base_batch or {} + self.update_generation_kwargs(generation_kwargs or {}) + + self.static_keys = static_keys + self.list_keys = list_keys + self.tensor_keys = tensor_keys + + hf_loading_vars = hf_loading_vars or {} + self.dataset: 
HFDataset = self.read_dataset(dataset_uri, destination_path, hf_loading_vars, hf_parsing_map) + self.strip_data = strip_dataset + if self.strip_data: + self.dataset = self.dataset.map(strip_data) + fewshot_rng = random.Random(fewshot_random_seed) - self.encoded_dataset = self._prep_examples(num_fewshot, prompt_string, example_delimiter, - continuation_delimiter, question_prelimiter, fewshot_rng, - cot_delimiter) + self.dataset: HFDataset = self.dataset.map( + self._prep_example, + with_indices=True, + fn_kwargs={ + 'num_fewshot': num_fewshot, + 'prompt_string': prompt_string, + 'fewshot_rng': fewshot_rng, + }, + ) - def _format_prompt_and_fewshot(self, num_fewshot: int, prompt_string: str, example_delimiter: str, - continuation_delimiter: str, question_prelimiter: str, cot_delimiter: str, - fewshot_rng: random.Random, sample_idx: int) -> str: - """Formats the prompt fewshot examples for test sample `sample_idx`. + def __getitem__(self, index: int) -> Dict: + return self.dataset[index] - Randomly select `num_fewshot` samples from the dataset (not including the sample at `sample_idx`) and format - them each as follows `{example_delimiter}{question_prelimiter}{context}{continuation_delimiter}{chain_of_thought}{cot_delimiter}{answer}`. + def __len__(self) -> int: + return len(self.dataset) - `chain_of_thought` will default to empty if not present in the dataset but `context` and `answer` must be present. + def get_num_samples_in_batch(self, batch: Dict) -> int: + return batch['input_ids'].shape[0] + + def update_generation_kwargs(self, generation_kwargs: Dict) -> None: + """ + Updates self.base_batch with the passed in generation_kwargs. + This must be run after self.base_batch is set (for example, if self.base_batch is set after __init__() is run, + likely because base_batch needs a class variable like self.pad_tok_id or self.max_answer_length). - Returns the formatted prompt_string + concatenated list of formatted few shot examples. + Args: + dict: Keyword arguments that be written into base_batch['generation_kwargs'] """ - prompt_and_fewshot = prompt_string + if 'generation_kwargs' not in self.base_batch: + self.base_batch['generation_kwargs'] = {} + if generation_kwargs: + self.base_batch['generation_kwargs'].update(generation_kwargs) + + def read_dataset(self, + dataset_uri: str, + destination_path: str, + hf_loading_vars: Optional[Dict[str, Any]] = None, + hf_parsing_map: Optional[Dict[str, Any]] = None) -> 'HFDataset': + """ + Reads a dataset and handles parsing it from HuggingFace. + + Args: + dataset_uri (str): A local path, a remote path beginning with ``s3://`` or another backend, or a HuggingFace dataset uri. + Alternate backends must be supported by :meth:`composer.utils.maybe_create_object_store_from_uri`. 
+ destination_path (str): A local path where the data will be stored + hf_loading_vars (Dict): If parsing from HuggingFace, keyword args that will be passed into load_dataset + hf_parsing_map (Dict): Dictionary in the form of {icl_key: [hf_col1, hf_col2]} that will map one or more hf columns, in order, to ICL dataset columns + + Returns: + dataset: A loaded HF dataset + """ + from datasets import Dataset as HFDataset # pyright: ignore[reportGeneralTypeIssues] + from datasets import load_dataset # pyright: ignore[reportGeneralTypeIssues] + if 'hf://' in dataset_uri: + dataset_uri = dataset_uri.replace('hf://', '') + if hf_loading_vars is None: + hf_loading_vars = {} + dataset = load_dataset(dataset_uri, **hf_loading_vars) + if hf_parsing_map: + dataset_parsing_func = lambda example: { + k: ' '.join([str(example[col]) for col in v]) + for k, v in hf_parsing_map.items() # pyright: ignore[reportOptionalMemberAccess] + } + assert isinstance(dataset, HFDataset) + dataset = dataset.map(dataset_parsing_func, remove_columns=dataset.column_names) + else: + with dist.local_rank_zero_download_and_wait(destination_path): + if dist.get_local_rank() == 0: + get_file(dataset_uri, destination_path, overwrite=True) + dataset = load_dataset('json', data_files=destination_path, split='train', streaming=False) + assert isinstance(dataset, HFDataset) + return dataset + + def _generate_few_shot_prompt( + self, + num_fewshot: int, + example_idx: int, + preamble: str, + fewshot_rng: random.Random, + ) -> str: + """ + Formats the fewshot prompt for test example `example_idx`. + + Randomly selects `num_fewshot` samples from the dataset (excluding the example at `example_idx`) and constructs + contextes with answers appended. + + Returns the formatted prompt_string + concatenated list of formatted few shot examples as a string. + + Args: + num_fewshot (int): Number of examples to prepend + example_idx (int): Current example idx + preamble (str): Text to occur at the beginning of the task. Generally instructions or a prompt. + fewshot_rng (random.Random): Seeded sampler to chose samples with + + Returns: + str: The original preamble with num_fewshot examples appended + """ + few_shot_text = preamble if num_fewshot > 0: - fewshot_idxs = _get_fewshot_sample_idxs(len(self.samples), num_fewshot, sample_idx, fewshot_rng) + fewshot_idxs = _get_fewshot_sample_idxs( + len(self.dataset), + num_fewshot, + example_idx, + fewshot_rng, + ) for fewshot_idx in fewshot_idxs: - context = self.samples[fewshot_idx]['context'] - chain_of_thought = self.samples[fewshot_idx].get('chain_of_thought', '') - answer = self.samples[fewshot_idx]['answer'] - - if len(chain_of_thought) == 0: - cot_delimiter = '' - context = f'{question_prelimiter}{context}' - if len(prompt_and_fewshot) > 0: - context = f'{example_delimiter}{context}' - prompt_and_fewshot += f'{context}{continuation_delimiter}{chain_of_thought}{cot_delimiter}{answer}' - - return prompt_and_fewshot - - def _prep_examples(self, - num_fewshot: int, - prompt_string: str, - example_delimiter: str, - continuation_delimiter: str, - question_prelimiter: str, - fewshot_rng: random.Random, - cot_delimiter: str = '') -> List[Dict[str, Any]]: - """Prepares a set of language modeling tasks into tokenized format with prompt and fewshot examples. 
+ ctxt = self.construct_context( + self.dataset[fewshot_idx], + few_shot_text, + add_answer=True, + ) + few_shot_text += ctxt + + return few_shot_text + + def construct_context(self, example: Dict, preceding_text: str = '', add_answer: bool = False) -> str: + """ + Takes an example and constructs a context, i.e. the input the model reads for this example. + Optionally adds the correct answer (for fewshot examples) and handles example delimiters + + Args: + example (Dict): The example from which to construct the context + preceding_text (str): Any preceding text, used as a check for prepending self.example_delimiter + add_answer (bool): Bool for whether or not to add the answer on the end of the context (e.g. for fewshot examples) + + Returns: + str: The constructed context. The default output context is + formatted as follows: f'{self.prelimiter}{example[self.context_key]}{self.continuation_delimiter}' + """ + ctxt = example[self.context_key] + ctxt = f'{self.prelimiter}{ctxt}' + if len(preceding_text) > 0: + ctxt = f'{self.example_delimiter}{ctxt}' + ctxt = f'{ctxt}{self.continuation_delimiter}' + if add_answer: + ctxt = f'{ctxt}{self.get_answer_from_example(example, in_context=add_answer)}' + return ctxt + + def get_answer_from_example(self, example: Dict[str, Any], in_context: bool = False) -> str: + """ + Returns the answer from the example. + + Args: + example (Dict): The example from which to retrieve the answer + + Returns: + str: The answer in the example + """ + cont = example[self.answer_key] + if self.prefix_space and not cont.startswith(' ') and not in_context: + cont = f' {cont}' + return cont + + def _fix_eos_on_preamble(self, input_ids: List[int]) -> List[int]: + """ + If the input_ids is empty then input_ids will be a 0-length List + unless the tokenizer adds special tokens to empty strings (e.g. OPT tokenizer). + If there is an EOS token added, we need to remove it so it is not in the middle of the prompt, + as the specific eval question's prompt will follow the input_ids. + + Args: + input_ids (List): The tokenized input + + Returns: + input_ids: The tokenized input conditionally edited + """ + if (self.tokenizer.eos_token_id is not None and len(input_ids) > 1 and + input_ids[-1] == self.tokenizer.eos_token_id): + input_ids = input_ids[:-1] + return input_ids + + def tokenize_example(self, prompt_and_fewshot: str, ctxt: str, example: Dict) -> Dict[str, Any]: + """ + Runs text through the tokenizer and handle special cases. + + Args: + prompt_and_fewshot (str): The collection of the prompt and fewshot examples that belongs before the example's context + ctxt (str): The specific example's derrived context + example (Dict): The example as a dictionary. Used for additional processing in inherited classes. 
+ + Returns: + Dict: Dictionary with the tokenized data + """ + tokenized_example = {} + # Always add special tokens to preamble + preamble = self.tokenizer(prompt_and_fewshot)['input_ids'] + assert isinstance(preamble, list) + preamble = self._fix_eos_on_preamble(preamble) + if self.strip_data: + # rstrip context because a prompt ending in a space results in degenerate output + ctxt = ctxt.rstrip() + # Never add special tokens to context + tokenized_context = self.tokenizer(ctxt, add_special_tokens=False)['input_ids'] + assert isinstance(preamble, list) + assert isinstance(tokenized_context, list) + + tokenized_context = preamble + tokenized_context + + if self.tokenize_labels: + # Never add special tokens to answer + tokenized_answer = self.tokenizer(self.get_answer_from_example(example), + add_special_tokens=False)['input_ids'] + assert isinstance(tokenized_answer, list) + trimmed_context = _trim_context(tokenized_context, tokenized_answer, self.padding_size) + assert isinstance(trimmed_context, list) + continuation_indices = _get_continuation_span(trimmed_context, tokenized_answer) + padded_context = _make_padded_input(trimmed_context, tokenized_answer, self.padding_size, self.pad_tok_id, + self.padding_side) + + tokenized_example[self.context_key] = padded_context + tokenized_example[self.answer_key] = tokenized_answer + tokenized_example['continuation_indices'] = continuation_indices + else: + assert isinstance(tokenized_context, list) + trimmed_context = _trim_context( + tokenized_context, + [], + self.padding_size, + ) + assert isinstance(trimmed_context, list) + padded_context = _make_padded_input(trimmed_context, [], self.padding_size, self.pad_tok_id, + self.padding_side) + + tokenized_example[self.context_key] = padded_context + tokenized_example[self.answer_key] = self.get_answer_from_example(example) + + return tokenized_example + + def _prep_example( + self, + example: Dict, + example_idx: int, + num_fewshot: int, + prompt_string: str, + fewshot_rng: random.Random, + ) -> Dict[str, Any]: + """ + Prepares a single example from a HF Dataset into tokenized format with prompt and fewshot examples. Each task consists of a context and a continuation as well as an optional prompt and optional list of example context/continuation pairs which precede the test context/continuation pair. Args: + example (Dict): A Dictionary from the hf dataset + example_idx (int): The index of example num_fewshot (int): Number of examples context/continuation pairs to prepend to the test pair prompt_string (str): The prompt to prepend to all inputs - example_delimiter (str): The delimiter used to separate each individual context/continuation pair - continuation_delimiter (str): The delimiter used to separate each context from its continuation - question_prelimiter (str): The text to prepend to each question fewshot_rng (random.Random): Random number generator to use for fewshot sampling - cot_delimiter (str): The delimiter used to separate the chain-of-thought (if present) from the final model response. 
+ Returns: + Dict: Contains a dictionary with the tokenized data + """ + prompt_and_fewshot = self._generate_few_shot_prompt(num_fewshot, example_idx, prompt_string, fewshot_rng) + ctxt = self.construct_context(example, prompt_and_fewshot, add_answer=False) + tokenized_example = self.tokenize_example(prompt_and_fewshot, ctxt, example) + return tokenized_example + + def collate_fn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]: + """ + The function that the dataloader uses to accumulate data into batches. + + Args: + data (List): List of tokenized datapoints (dicts returned by self._tokenize_example) Returns: - dict: Contains the context, the continuation, and the preamble (prompt + fewshot examples) + Dict: Dictionary for a single batch """ - max_answer_length = 0 - has_cot = False - examples = [] - for sample_idx in tqdm(range(len(self.samples))): - encoded_example = {} - - prompt_and_fewshot = self._format_prompt_and_fewshot(num_fewshot, prompt_string, example_delimiter, - continuation_delimiter, question_prelimiter, - cot_delimiter, fewshot_rng, sample_idx) - - ctxt = self.samples[sample_idx]['context'] - ctxt = f'{question_prelimiter}{ctxt}' - if len(prompt_and_fewshot) > 0: - ctxt = f'{example_delimiter}{ctxt}' - - # rstrip the continuation delimiter, because the prompt ending in a space results in degenerate output - continuation_delimiter_stripped = continuation_delimiter.rstrip() - ctxt = f'{ctxt}{continuation_delimiter_stripped}' - - # If the preamble is empty then this will be a 0-length list, unless the tokenizer adds special tokens to empty strings (e.g. OPT tokenizer) - encoded_example['preamble'] = self.tokenizer(prompt_and_fewshot) - # If there is an EOS token added, we need to remove it so it is not in the middle of the prompt - example_ids = encoded_example['preamble']['input_ids'] - if (self.tokenizer.eos_token_id is not None and - len(example_ids) > 1 and # pyright: ignore[reportGeneralTypeIssues] - example_ids[-1] == self.tokenizer.eos_token_id): # pyright: ignore[reportGeneralTypeIssues] - encoded_example['preamble']['input_ids'] = example_ids[:-1] # pyright: ignore[reportGeneralTypeIssues] - - encoded_example['context'] = self.tokenizer(ctxt, add_special_tokens=False) - encoded_example['aliases'] = list(self.samples[sample_idx]['aliases']) - encoded_example['cot_delimiter'] = cot_delimiter - examples.append(encoded_example) - for answer in self.samples[sample_idx]['aliases']: - response = f"{self.samples[sample_idx]['chain_of_thought']}{cot_delimiter}{answer}" - max_answer_length = max(max_answer_length, len( - self.tokenizer(response)['input_ids'])) # pyright: ignore[reportGeneralTypeIssues] - - if len(self.samples[sample_idx]['chain_of_thought']) > 0: - has_cot = True - - self.max_answer_length = max_answer_length + (_MAX_ANSWER_BUFFER_LENGTH if has_cot else 0) - return examples - - def __getitem__(self, index): - return self.encoded_dataset[index] - - def __len__(self): - return len(self.encoded_dataset) - - def collate_fn(self, data): - inputs, answers = [], [] - cot_delimiter = '' - - for sample in data: - preamble, context, aliases = (sample['preamble'], sample['context'], sample['aliases']) - context_enc = preamble['input_ids'] + context['input_ids'] - inp, _ = _make_padded_input(context_enc, [], - self.max_seq_len - self.max_answer_length, - self.pad_tok_id, - padding_side=self.padding_side) - - inputs.append(inp) - answers.append(aliases) - - # We will search for the answer within the portion of the model response - # beginning with `cot_delimiter` - 
cot_delimiter = sample['cot_delimiter'] - stopping_criteria = None - if self.early_stopping_criteria: - if stop_sequences_criteria is None: # pyright: ignore [reportUnnecessaryComparison] - raise MissingConditionalImportError(extra_deps_group='nlp', - conda_package='transformers', - conda_channel='conda-forge') - stopping_criteria = stop_sequences_criteria(self.tokenizer, self.early_stopping_criteria, len(inputs)) - batch = { - 'input_ids': torch.stack(inputs), - 'mode': 'generate', - 'labels': answers, - 'cot_delimiter': cot_delimiter, - 'generation_length': self.max_answer_length, - 'stopping_criteria': self.early_stopping_criteria, - 'do_normalization': self.do_normalization, - 'generation_kwargs': { - 'pad_token_id': self.pad_tok_id, - 'use_cache': True, - 'stopping_criteria': stopping_criteria, - 'eos_token_id': self.tokenizer.eos_token_id, - } - } + batch = copy.deepcopy(self.base_batch) + for data_pair in data: + for batch_key, data_key in self.batch_mapping.items(): + batch[batch_key].append(data_pair[data_key]) + if 'continuation_indices' in data_pair: + batch['continuation_indices'].append(data_pair['continuation_indices']) + batch = convert_tokens_to_tensors(batch, self.tokenize_labels) batch['attention_mask'] = ~(batch['input_ids'] == self.pad_tok_id) return batch - def get_num_samples_in_batch(self, batch) -> int: - return batch['input_ids'].shape[0] + def split_batch(self, batch: Any, microbatch_size: int) -> List[Dict[str, Any]]: + """ + Handling for certain specialty columns that must be split into batches in different formats. + + Args: + batch (Dict): Batch of data + microbatch_size (int): Size of microbatches - def split_batch(self, batch: Any, microbatch_size: int): + Returns: + List: List of chunked batches + """ # Don't split kwargs that don't change # Normally split torch tensors # List split lists of strings - no_split = [ - 'mode', 'generation_length', 'generation_kwargs', 'cot_delimiter', 'do_normalization', 'stopping_criteria' - ] - normal_split = ['input_ids', 'attention_mask'] - list_split = ['labels'] chunked = {} for k, v in batch.items(): - if k in no_split: + if k in self.static_keys: # Defer broadcasting until we know num_chunks pass - elif k in list_split: + elif k in self.list_keys: chunked[k] = _split_list(v, microbatch_size) - elif k in normal_split: + elif k in self.tensor_keys: chunked[k] = _default_split_batch(v, microbatch_size) else: - raise ValueError(f'Unexpected key {k}') + raise ValueError(f'Unexpected key {k} in batch splitting') num_chunks = len(chunked['input_ids']) for k, v in batch.items(): - if k in no_split: + if k in self.static_keys: chunked[k] = [v] * num_chunks - return [{k: v[idx] for k, v in chunked.items()} for idx in range(num_chunks)] + batched_list = [{k: v[idx] for k, v in chunked.items()} for idx in range(num_chunks)] + return batched_list -class InContextLearningLMTaskDataset(Dataset): - """A dataset that construct batches for in-context learning language modeling evaluation - Args: - dataset_uri (str): Either a local path, or a remote path beginning with ``s3://``, or another backend - supported by :meth:`composer.utils.maybe_create_object_store_from_uri`. Dataset must consist of rows of JSON data points with "context", - and "continuation". See tests/datasets/local_data/lambada_small.jsonl. 
- tokenizer (Union[transformers.PreTrainedTokenizer, transformers.PreTrainedTokenizerFast]): The tokenizer used to transform data into batches - batch_size (int): Size of a batch used for eval - max_seq_len (int): The sequence length expected by the model - pad_tok_id (int): The special token reserved for padding the ends of batches - num_fewshot (int): The number of complete fewshot examples to prepend before each test example - prompt_string (str): Prompt string to put once before all fewshot examples/test examples (e.g. 'translate english to french') - example_delimiter (str): Separator that goes between individual (context, continuation) pairs (e.g. '\n') - continuation_delimiter: (str): Separator that goes between context and continuation in each example (e.g. '->') - destination_path (str): Temporary path to store downloaded datasets - fewshot_random_seed (int): Random seed used to select fewshot examples +class InContextLearningQATaskDataset(InContextLearningDataset): """ + A dataset that constructs batches for in-context learning question answering evaluation. + QA tasks evaluate a model's ability to answer questions using a consistent format. - def __init__( - self, - dataset_uri: str, - tokenizer: Union[transformers.PreTrainedTokenizer, transformers.PreTrainedTokenizerFast], - max_seq_len: int, - pad_tok_id: int, - num_fewshot: int, - prompt_string: str, - example_delimiter: str, - continuation_delimiter: str, - destination_path: str, - fewshot_random_seed: int, - ): - try: - from datasets import load_dataset # pyright: ignore [reportGeneralTypeIssues] - except ImportError as e: - raise MissingConditionalImportError(extra_deps_group='nlp', - conda_package='datasets', - conda_channel='conda-forge') from e - with dist.local_rank_zero_download_and_wait(destination_path): - if dist.get_local_rank() == 0: - get_file(dataset_uri, destination_path, overwrite=True) - dataset = load_dataset('json', data_files=destination_path, split='train', streaming=False) - self.samples = list( - dataset.map(lambda examples: { - 'continuation': examples['continuation'], - 'context': examples['context'], - })) - self.samples = strip_data(self.samples) - - self.tokenizer = tokenizer - self.max_seq_len = max_seq_len - self.pad_tok_id = pad_tok_id - fewshot_rng = random.Random(fewshot_random_seed) + The input format is expected to be a jsonl file with the following fields: + - context: The question + - answer: The preferred answer to the question + - aliases: A list of aliases for the answer - self.prefix_space = _tokenizer_needs_prefix_space(self.tokenizer) + See InContextLearningDataset for more details. - self.encoded_dataset = self.prep_examples(num_fewshot, prompt_string, example_delimiter, continuation_delimiter, - fewshot_rng) + Additional Args: + cot_delimiter (str): Delimiter to place between the chain of thought and continuations. + """ - def prep_examples(self, num_fewshot: int, prompt_string: str, example_delimiter: str, continuation_delimiter: str, - fewshot_rng: random.Random): - """Prepares a set of language modeling tasks into tokenized format with prompt and fewshot examples. 
+ def __init__(self, + cot_delimiter: str = '', + early_stopping_criteria: Optional[List[str]] = None, + do_normalization: bool = True, + *args, + **kwargs): + if kwargs['tokenizer'].eos_token_id is None: + raise ValueError('`InContextLearningQATaskDataset` tokenizer must have non-null `eos_token_id`') + self.cot_delimiter = cot_delimiter + self.has_cot = False + self.max_answer_length = 0 + static_keys = [ + 'mode', 'cot_delimiter', 'generation_length', 'generation_kwargs', 'do_normalization', 'stopping_criteria' + ] + tensor_keys = ['input_ids', 'attention_mask'] + list_keys = ['labels'] + super().__init__(padding_side='left', + tokenize_labels=False, + static_keys=static_keys, + list_keys=list_keys, + tensor_keys=tensor_keys, + *args, + **kwargs) + # NOTE: set these after init call because they take class vars + self.early_stopping_criteria = early_stopping_criteria + self.base_batch = { + 'input_ids': [], + 'mode': 'generate', + 'labels': [], + 'cot_delimiter': self.cot_delimiter, + 'generation_length': self.max_answer_length, + 'stopping_criteria': early_stopping_criteria, + 'do_normalization': do_normalization, + 'generation_kwargs': { + 'pad_token_id': self.pad_tok_id, + 'use_cache': True, + 'eos_token_id': self.tokenizer.eos_token_id, + } + } + self.batch_mapping = { + 'input_ids': self.context_key, + 'labels': 'aliases', + } + self.update_generation_kwargs(kwargs.get('generation_kwargs', {})) - Each task consists of a context and a continuation as well as an optional prompt and optional list of - example context/continuation pairs which precede the test context/continuation pair. + def read_dataset( + self, + dataset_uri: str, + destination_path: str, + hf_loading_vars: Dict, + hf_parsing_map: Dict, + ) -> 'HFDataset': + dataset = super().read_dataset(dataset_uri, destination_path, hf_loading_vars, hf_parsing_map) + self.has_cot = 'chain_of_thought' in dataset.features + dataset = dataset.map( + lambda examples: { + 'context': examples['context'], + 'answer': examples['answer'], + 'aliases': set([examples['answer']] + examples.get('aliases', [])), + 'chain_of_thought': examples.get('chain_of_thought', ''), + }) + self.max_answer_length = self._get_max_answer_length(dataset) + # NOTE: This is the only time we use the class variable padding_size. + self.padding_size = self.max_seq_len - self.max_answer_length + return dataset + def get_answer_from_example(self, example: Dict, in_context=False) -> str: + """ + Returns the answer from the example. Applies chain of thought if self.has_cot is marked as true. 
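# --- Illustrative sketch (editor's aside, not part of this diff) ---
# With chain of thought present, the answer string is the chain of thought joined to the
# gold answer by cot_delimiter. The example row and delimiter below are made up.
example = {'chain_of_thought': 'Two sixes make twelve', 'answer': '12'}
cot_delimiter = '. So the answer is '
rendered = f"{example['chain_of_thought']}{cot_delimiter}{example['answer']}"
# -> 'Two sixes make twelve. So the answer is 12'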
Args: - num_fewshot (int): Number of examples context/continuation pairs to prepend to the test pair - prompt_string (str): The prompt to prepend to all inputs - example_delimiter (str): The delimiter used to separate each individual context/continuation pair - continuation_delimiter (str): The delimiter used to separate each context from its continuation - fewshot_rng (random.Random): Random number generator used to select fewshot examples + example (Dict): The example from which to retrieve the answer Returns: - dict: Contains the context, the continuation, and the preamble (prompt + fewshot examples) + str: The answer in from the example with chain of thought and delimiter if needed """ - examples = [] - for sample_idx in tqdm(range(len(self.samples))): - encoded_example = {} - - preamble = prompt_string - - if num_fewshot > 0: - fewshot_idxs = _get_fewshot_sample_idxs(len(self.samples), num_fewshot, sample_idx, fewshot_rng) - for fewshot_idx in fewshot_idxs: - ctxt, cont = self.samples[fewshot_idx]['context'], self.samples[fewshot_idx]['continuation'] - if len(preamble) > 0: - ctxt = f'{example_delimiter}{ctxt}' - preamble += f'{ctxt}{continuation_delimiter}{cont}' - - ctxt, cont = self.samples[sample_idx]['context'], self.samples[sample_idx]['continuation'] - if len(preamble) > 0: - ctxt = f'{example_delimiter}{ctxt}' - - # rstrip the continuation delimiter, because the prompt ending in a space results in degenerate output - continuation_delimiter_stripped = continuation_delimiter.rstrip() - - if self.prefix_space and not cont.startswith(' '): - cont = f' {cont}' - ctxt += continuation_delimiter_stripped - - encoded_example['preamble'] = self.tokenizer( - preamble - ) # if the preamble is empty then these will be 0-length lists, unless the tokenizer adds special tokens to empty strings (e.g. OPT tokenizer) - example_ids = encoded_example['preamble']['input_ids'] - if (self.tokenizer.eos_token_id is not None and - len(encoded_example['preamble']['input_ids']) > 1 and # pyright: ignore[reportGeneralTypeIssues] - example_ids[-1] == self.tokenizer.eos_token_id): # pyright: ignore[reportGeneralTypeIssues] - encoded_example['preamble']['input_ids'] = example_ids[:-1] # pyright: ignore[reportGeneralTypeIssues] + if self.has_cot: + return f'{example["chain_of_thought"]}{self.cot_delimiter}{example[self.answer_key]}' + else: + return example[self.answer_key] - encoded_example['context'] = self.tokenizer(ctxt, add_special_tokens=False) - encoded_example['continuation'] = self.tokenizer(cont, add_special_tokens=False) - - examples.append(encoded_example) - - return examples - - def __getitem__(self, index): - return self.encoded_dataset[index] - - def __len__(self): - return len(self.encoded_dataset) - - def collate_fn(self, data): - inputs = [] - continuation_indices = [] - for data_pair in data: - preamble, context, continuation = (data_pair['preamble'], data_pair['context'], data_pair['continuation']) + def tokenize_example(self, prompt_and_fewshot: str, ctxt: str, example: Dict) -> Dict[str, Any]: + """ + Run text through the tokenizer and handle special cases. + Args: + prompt_and_fewshot (str): The collection of the prompt and fewshot examples that belongs before the example's context + ctx (str): The specific example's derrived context + example (Dict): The example as a dictionary. 
- context_enc = preamble['input_ids'] + context['input_ids'] - continuation_enc = continuation['input_ids'] + Returns: + Dict: Dictionary with the tokenized data + """ + tokenized_example = super().tokenize_example(prompt_and_fewshot, ctxt, example) + tokenized_example['aliases'] = list(example.get('aliases', [])) + return tokenized_example - inp, continuation_span = _make_padded_input(context_enc, continuation_enc, self.max_seq_len, - self.pad_tok_id) + def _get_max_answer_length(self, dataset) -> int: + f""" + Loops over the dataset and finds the longest answer length. - inputs.append(inp) - continuation_indices.append(continuation_span) + Returns: + int: The maximum answer length with an additional buffer of {_MAX_ANSWER_BUFFER_LENGTH} if chain of thought is present + """ + max_answer_length = 0 + for example in dataset: + all_answers = [example[self.answer_key]] + list(example.get('aliases', [])) + for answer in all_answers: + if self.has_cot: + response = (f'{example["chain_of_thought"]}{self.cot_delimiter}{answer}') + else: + response = answer + tokenized_repsonse = self.tokenizer(response)['input_ids'] + assert isinstance(tokenized_repsonse, list) + max_answer_length = max(max_answer_length, len(tokenized_repsonse)) + max_answer_length = max_answer_length + (_MAX_ANSWER_BUFFER_LENGTH if len(self.cot_delimiter) > 0 else 0) + return max_answer_length + + def collate_fn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]: + batch = super().collate_fn(data) + batch_size = batch['input_ids'].shape[0] + stopping_criteria = None + if self.early_stopping_criteria: + if stop_sequences_criteria is None: # pyright: ignore [reportUnnecessaryComparison] + raise MissingConditionalImportError(extra_deps_group='nlp', + conda_package='transformers', + conda_channel='conda-forge') + stopping_criteria = stop_sequences_criteria(self.tokenizer, self.early_stopping_criteria, batch_size) + batch['generation_kwargs']['stopping_criteria'] = stopping_criteria + return batch - batch = { - 'input_ids': torch.stack(inputs), - 'continuation_indices': continuation_indices, - 'mode': 'icl_task', - 'labels': torch.stack(inputs), - } - batch['attention_mask'] = ~(batch['input_ids'] == self.pad_tok_id) - return batch +class InContextLearningLMTaskDataset(InContextLearningDataset): + """ + A dataset that constructs batches for in-context learning language modeling evaluation. + Language modeling tasks test a model's ability to properly predict tokens based on preceding tokens. - def get_num_samples_in_batch(self, batch) -> int: - return batch['input_ids'].shape[0] + The input format is expected to be a jsonl file with the following fields: + - context: Preceding text + - continuation: The expected continuation + See InContextLearningDataset for more details. + """ -class InContextLearningMultipleChoiceTaskDataset(Dataset): - """A dataset that construct batches for in-context learning multiple choice evaluation + def __init__(self, *args, **kwargs): + super().__init__(answer_key='continuation', + static_keys=['mode'], + tensor_keys=['input_ids', 'continuation_indices', 'labels', 'attention_mask'], + base_batch={ + 'input_ids': [], + 'continuation_indices': [], + 'mode': 'icl_task', + 'labels': [] + }, + batch_mapping={ + 'input_ids': 'context', + 'labels': 'context' + }, + padding_side='right', + *args, + **kwargs) + + +class InContextLearningMultipleChoiceTaskDataset(InContextLearningDataset): + """ + A dataset that construct batches for in-context learning multiple choice evaluation. 
If each question has N answer choices, we construct N distinct inputs per question. In order to ensure consistency across multi-GPU, we set the batch size to be `min(N, batch_size)` so that all N inputs per question can stored in the same batch. - Each batch then consists of batch_size // N distinct questions and has the following the structure - - 'input_ids': Input tensor batch x seqlen x # tokens - 'continuation_indices': List of |batch| consisting of tensors indicating which indices in the sequence correspond to the question answer (aka continuation) - 'mode': Indicates to the model that this is an ICL task and may rely on a custom code path to properly update metrics - 'labels': Identical to the input, used by the model to calculate loss/metrics - 'gold_indices': List of length |batch_size // N| indicating for each question, which of the answers is correct (via an integer [0, N-1]) - 'choice_groupings': Indicates which indices of the batch correspond to which questions - - Args: - dataset_uri (str): Either a local path, or a remote path beginning with ``s3://``, or another backend - supported by :meth:`composer.utils.maybe_create_object_store_from_uri`. Dataset must consist of rows of JSON data points with "query", - "choices", and "gold" index. See tests/datasets/local_data/piqa_small.jsonl. - tokenizer (Union[transformers.PreTrainedTokenizer, transformers.PreTrainedTokenizerFast]): The tokenizer used to transform data into batches - batch_size (int): Size of a batch used for eval - max_seq_len (int): The sequence length expected by the model - pad_tok_id (int): The special token reserved for padding the ends of batches - num_fewshot (int): The number of complete fewshot examples to prepend before each test example - prompt_string (str): Prompt string to put once before all fewshot examples/test examples (e.g. 'translate english to french') - example_delimiter (str): Separator that goes between individual (context, continuation) pairs (e.g. '\n') - continuation_delimiter: (str): Separator that goes between context and continuation in each example (e.g. '->') - destination_path (str): Temporary path to store downloaded datasets - fewshot_random_seed (int): Random seed used to select fewshot examples + The default input format is a jsonl file with the following fields: + - query: The preceding text, question, or document relevant to the choices + - gold: Index of the correct choice under 'choices' + - choices: A list of strings, each being one of the potential choices + + Each batch then consists of ``|batch_size // N|`` distinct questions and has the following the structure. + - input_ids: Input tensor ``|batch x seqlen x # tokens|`` + - continuation_indices: List of ``|batch|`` consisting of tensors indicating which indices in the sequence correspond to the question answer (aka continuation) + - mode: Indicates to the model that this is an ICL task and may rely on a custom code path to properly update metrics + - labels: Identical to the input, used by the model to calculate loss/metrics + - gold_indices: List of length ``|batch_size // N|`` indicating for each question, which of the answers is correct (via an integer [0, N-1]) + - choice_groupings: Indicates which indices of the batch correspond to which questions + + Additional Args: + choices_key (str): The key under which the choices are stored in the saved dataset. Defaults to 'choices'. 
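# --- Illustrative sketch (editor's aside, not part of this diff) ---
# One row of the default jsonl input format described above (values are made up).
mc_row = {
    'query': 'The glass fell off the table, so it',
    'choices': [' bounced.', ' shattered.'],
    'gold': 1,  # index into 'choices' of the correct answer
}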
""" - def __init__( - self, - dataset_uri: str, - tokenizer: Union[transformers.PreTrainedTokenizer, transformers.PreTrainedTokenizerFast], - max_seq_len: int, - pad_tok_id: int, - num_fewshot: int, - prompt_string: str, - example_delimiter: str, - continuation_delimiter: str, - destination_path: str, - fewshot_random_seed: int, - ): - try: - from datasets import load_dataset # pyright: ignore [reportGeneralTypeIssues] - except ImportError as e: - raise MissingConditionalImportError(extra_deps_group='nlp', - conda_package='datasets', - conda_channel='conda-forge') from e - - with dist.local_rank_zero_download_and_wait(destination_path): - if dist.get_local_rank() == 0: - get_file(dataset_uri, destination_path, overwrite=True) - dataset = load_dataset('json', data_files=destination_path, split='train', streaming=False) - self.samples = list( - dataset.map(lambda examples: { - 'query': examples['query'], - 'choices': examples['choices'], - 'gold': examples['gold'] - })) - self.samples = strip_data(self.samples) - - self.num_choices = len(self.samples[0]['choices']) - self.tokenizer = tokenizer - self.max_seq_len = max_seq_len - self.pad_tok_id = pad_tok_id - fewshot_rng = random.Random(fewshot_random_seed) - - self.prefix_space = _tokenizer_needs_prefix_space(self.tokenizer) - - self.encoded_dataset = self.prep_examples(num_fewshot, prompt_string, example_delimiter, continuation_delimiter, - fewshot_rng) - - def prep_examples(self, num_fewshot: int, prompt_string: str, example_delimiter: str, continuation_delimiter: str, - fewshot_rng: random.Random): - """Prepares a set of multiple choice questions into tokenized format with prompt and few shot examples. - - Each question consists of a query and set of answer choices, only one of which is correct. At inference time - we construct individual inference examples consisting of the query + a single choice, as well as an optional (prompt) and optional list - of example query + correct answers, which precede the test query + choice. + def __init__(self, + choices_key: str = 'choices', + static_keys: Optional[List] = None, + list_of_tensors_keys: Optional[List] = None, + list_of_tuples_keys: Optional[List] = None, + list_of_primitives: Optional[List] = None, + *args, + **kwargs): + self.choices_key = choices_key + base_batch = { + 'input_ids': [], + 'continuation_indices': [], + 'mode': 'icl_task', + 'labels': [], + 'gold_indices': [], + 'choice_groupings': [], + } + context_key = kwargs.pop('context_key', 'query') + static_keys = kwargs.pop('static_keys', ['mode', 'generation_kwargs']) + tensor_keys = kwargs.pop('tensor_keys', ['input_ids', 'labels', 'attention_mask']) + self.list_of_tensors_keys = list_of_tensors_keys or ['continuation_indices'] + self.list_of_tuples_keys = list_of_tuples_keys or ['choice_groupings'] + self.list_of_primitives = list_of_primitives or ['gold_indices'] + super().__init__(context_key=context_key, + base_batch=base_batch, + static_keys=static_keys, + tensor_keys=tensor_keys, + padding_side='right', + *args, + **kwargs) + self.num_choices = len(self.dataset[0][self.choices_key]) + self.batch_mapping_per_choice = {'input_ids': 'context', 'labels': 'context'} + self.batch_map_per_example = {'gold_indices': 'gold'} + + def get_answer_from_example(self, example: Dict, in_context=False) -> str: + """ + Returns the correct answer from the example's choices. 
+ Args: + example (Dict): The example from which to retrieve the answer - For multiple choice, this method provides information relaying which of the answer choices is the correct one. This - information is used for computing accuracy metrics. + Returns: + str: The full string of the correct answer based on the 'gold' key + """ + choices = example[self.choices_key] + gold_idx = example['gold'] + return choices[gold_idx] + def tokenize_example(self, prompt_and_fewshot: str, ctxt: str, example: Dict) -> Dict[str, Any]: + """ + Runs text through the tokenizer and handle special cases. Args: - num_fewshot (int): Number of examples context/continuation pairs to prepend to the test pair - prompt_string (str): The prompt to prepend to all inputs - example_delimiter (str): The delimiter used to separate each example query/answer pair - continuation_delimiter (str): The delimiter used to separate each query from its answer - fewshot_rng (random.Random): Random number generator used to select fewshot examples + prompt_and_fewshot (str): The collection of the prompt and fewshot examples that belongs before the example's context + ctx (str): The specific example's derrived context + example (Dict): The example as a dictionary. Returns: - dict: Contains the query, the list of encoded potential answer choices, the preamble (prompt + fewshot examples), and - the index of the correct answer choice. + Dict: Dictionary with the tokenized data """ - examples = [] - for sample_idx in tqdm(range(len(self.samples))): - - preamble = prompt_string - if num_fewshot > 0: - fewshot_idxs = _get_fewshot_sample_idxs(len(self.samples), num_fewshot, sample_idx, fewshot_rng) - for fewshot_idx in fewshot_idxs: - query, choices, gold_idx = self.samples[fewshot_idx]['query'], self.samples[fewshot_idx][ - 'choices'], self.samples[fewshot_idx]['gold'] - if len(preamble) > 0: - query = f'{example_delimiter}{query}' - assert isinstance(gold_idx, int) - preamble += f'{query}{continuation_delimiter}{choices[gold_idx]}' - encoded_example = {} - query, choices, gold_idx = self.samples[sample_idx]['query'], self.samples[sample_idx][ - 'choices'], self.samples[sample_idx]['gold'], - if len(preamble) > 0: - query = f'{example_delimiter}{query}' - - # rstrip the continuation delimiter, because the prompt ending in a space results in degenerate output - continuation_delimiter_stripped = continuation_delimiter.rstrip() - + # NOTE: some of this is repeated from super class but for loop makes things considerably different + tokenized_example = {} + # Always add special tokens to preamble + preamble = self.tokenizer(prompt_and_fewshot)['input_ids'] + assert isinstance(preamble, list) + preamble = self._fix_eos_on_preamble(preamble) + if self.strip_data: + # rstrip context because a prompt ending in a space results in degenerate output + ctxt = ctxt.rstrip() + # Never add special tokens to context + tokenized_context = self.tokenizer(ctxt, add_special_tokens=False)['input_ids'] + assert isinstance(tokenized_context, list) + tokenized_context = preamble + tokenized_context + + tokenized_example[self.context_key] = [] + tokenized_example[self.answer_key] = [] + tokenized_example['continuation_indices'] = [] + # NOTE: Treating tokenize_labels as True for all MC datasets (required for our MC accuracy metric) + for choice in example[self.choices_key]: if self.prefix_space: - choices = [(f' {choice}' if not choice.startswith(' ') else choice) for choice in choices] - query += continuation_delimiter_stripped - encoded_example['preamble'] = 
self.tokenizer( - preamble - ) # if the preamble is empty then these will be 0-length lists, unless the tokenizer adds special tokens to empty strings (e.g. OPT tokenizer) - - example_ids = encoded_example['preamble']['input_ids'] - if (self.tokenizer.eos_token_id is not None and - len(example_ids) > 1 and # pyright: ignore[reportGeneralTypeIssues] - example_ids[-1] == self.tokenizer.eos_token_id): # pyright: ignore[reportGeneralTypeIssues] - encoded_example['preamble']['input_ids'] = example_ids[:-1] # pyright: ignore[reportGeneralTypeIssues] - - encoded_example['gold_idx'] = gold_idx - - encoded_example['query'] = self.tokenizer(query, add_special_tokens=False) - encoded_example['choices'] = [self.tokenizer(choice, add_special_tokens=False) for choice in choices] - - examples.append(encoded_example) + choice = f' {choice}' if not choice.startswith(' ') else choice + + # Never add special tokens to answer + tokenized_answer = self.tokenizer(choice, add_special_tokens=False)['input_ids'] + assert isinstance(tokenized_context, list) + assert isinstance(tokenized_answer, list) + trimmed_context = _trim_context(tokenized_context, tokenized_answer, self.padding_size) + assert isinstance(trimmed_context, list) + continuation_indices = _get_continuation_span(trimmed_context, tokenized_answer) + padded_context = _make_padded_input( + trimmed_context, + tokenized_answer, + self.padding_size, + self.pad_tok_id, + self.padding_side, + ) - return examples + tokenized_example[self.context_key].append(padded_context) + tokenized_example[self.answer_key].append(tokenized_answer) + tokenized_example['continuation_indices'].append(continuation_indices) - def __getitem__(self, index): - return self.encoded_dataset[index] + tokenized_example['gold'] = example['gold'] + return tokenized_example - def __len__(self): - return len(self.encoded_dataset) + def collate_fn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]: + """ + The function that the dataloader uses to accumulate data into batches. + We run each distinct query + answer choice through the model separately and determine which + answer has the lowest per-token-perplexity. + + If each question has N possible choices, all N must be grouped together as distinct elements of the batch + since the batch may consist of multiple questions, the choice_groupings indicates + which contiguous sequences of elements in the batch correspond to which question + gold_indices indicates which of the [0, N-1] choices is the correct one for each question. 
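# --- Rough sketch of the collated multiple choice batch described above (shapes only). ---
# With 2 questions and N = 4 choices each, the batch holds 8 padded rows; the values below
# are placeholders, only the structure and shapes are meant to be illustrative.
import torch

num_questions, num_choices, seq_len, pad_tok_id = 2, 4, 16, 0
input_ids = torch.randint(1, 100, (num_questions * num_choices, seq_len))
toy_batch = {
    'input_ids': input_ids,
    'labels': input_ids.clone(),  # identical to the input
    'mode': 'icl_task',
    'continuation_indices': [torch.arange(10, 12) for _ in range(8)],  # positions of the answer tokens in each row
    'gold_indices': [0, 2],  # correct choice per question, each in [0, N-1]
    'choice_groupings': [(0, 4), (4, 8)],  # rows 0-3 belong to question 0, rows 4-7 to question 1
    'attention_mask': ~(input_ids == pad_tok_id),
}
# Logical sample count is rows // num_choices, matching get_num_samples_in_batch defined below.
assert toy_batch['input_ids'].shape[0] // num_choices == num_questions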
+ Args: + data (List): List of tokenized datapoints (dicts returned by self._tokenize_example) - def collate_fn(self, data): - inputs = [] - continuation_indices = [] - gold_idxs = [] - choice_groupings = [] + Returns: + Dict: Dictionary for a single batch + """ + batch = copy.deepcopy(self.base_batch) for data_pair in data: - - choice_start_idx = len(continuation_indices) - preamble, context, choices, gold_idx = (data_pair['preamble'], data_pair['query'], data_pair['choices'], - data_pair['gold_idx']) - - for choice in choices: - context_enc = preamble['input_ids'] + context['input_ids'] - continuation_enc = choice['input_ids'] - inp, continuation_span = _make_padded_input(context_enc, continuation_enc, self.max_seq_len, - self.pad_tok_id) - - inputs.append(inp) - continuation_indices.append(continuation_span) - - gold_idxs.append(gold_idx) - choice_end_idx = len(continuation_indices) - choice_groupings.append((choice_start_idx, choice_end_idx)) - - # We run each distinct query + answer choice through the model separately and determine which - # answer has the lowest per-token-perplexity. - # - # If each question has N possible choices, all N must be grouped together as distinct elements of the batch - # since the batch may consist of multiple questions, the choice_groupings indicates - # which contiguous sequences of elements in the batch correspond to which question - # gold_indices indicates which of the [0, N-1] choices is the correct one for each question. - batch = { - 'input_ids': torch.stack(inputs), - 'continuation_indices': continuation_indices, - 'mode': 'icl_task', - 'labels': torch.stack(inputs), - 'gold_indices': gold_idxs, - 'choice_groupings': choice_groupings - } + choice_start_idx = len(batch['continuation_indices']) + # NOTE: not using batch_mapping + for i, context_enc in enumerate(data_pair[self.context_key]): + batch['input_ids'].append(context_enc) + batch['continuation_indices'].append(data_pair['continuation_indices'][i]) + batch['labels'].append(context_enc) + + batch['gold_indices'].append(data_pair['gold']) + choice_end_idx = len(batch['continuation_indices']) + batch['choice_groupings'].append((choice_start_idx, choice_end_idx)) + + batch = convert_tokens_to_tensors(batch, self.tokenize_labels) batch['attention_mask'] = ~(batch['input_ids'] == self.pad_tok_id) return batch def get_num_samples_in_batch(self, batch) -> int: return batch['input_ids'].shape[0] // self.num_choices - def split_batch(self, batch: Any, microbatch_size: int): - """Split batch while ensuring all continuations are in the same microbatch. + def split_batch(self, batch: Any, microbatch_size: int) -> List[Dict[str, Any]]: + """ + Split batch while ensuring all continuations are in the same microbatch. In ICL Multiple Choice, we duplicate each data point for each possible continuation. - When splitting a batch, we have logical samples, which refer to one possible question, - and real samples, which refers to one possible continuation. As sample count and - microbatch_size are tracked in logical samples, we split logical attributes by + When splitting a batch, we have logical example, which refer to one possible question, + and real example, which refers to one possible continuation. As example count and + microbatch_size are tracked in logical example, we split logical attributes by microbatch_size and real attributes by microbatch_size * num_choices. 
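# --- Numeric illustration (assumed numbers) of the logical vs. real split described above. ---
# Each logical example (one question) expands into num_choices real rows, so tensor keys are
# split by microbatch_size * num_choices while per-question keys are split by microbatch_size.
num_choices = 4
microbatch_size = 2                                    # counted in logical examples (questions)
logical_examples = 8                                   # questions in the full batch
real_rows = logical_examples * num_choices             # rows in input_ids / labels / attention_mask
rows_per_microbatch = microbatch_size * num_choices    # chunk size for the tensor keys
num_chunks = logical_examples // microbatch_size
assert (real_rows, rows_per_microbatch, num_chunks) == (32, 8, 4)
# e.g. 'gold_indices' is chunked into lists of 2, 'input_ids' into tensors of 8 rows,
# and static keys such as 'mode' are broadcast to all 4 chunks.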
+ Args: + batch (Dict): Batch of data + microbatch_size (int): Size of microbatches + + Returns: + list: List of chunked batches """ - # Don't split kwargs that don't change - # Normally split torch tensors - # List split lists of strings - no_split = ['mode'] - # Real - real = ['input_ids', 'labels', 'attention_mask'] - logical = ['gold_indices'] chunked = {} for k, v in batch.items(): - if k in no_split: + if k in self.static_keys: # Defer broadcasting primitives until we know num_chunks pass - elif k == 'continuation_indices': - # List of list, so we have to directly call _split_list - chunked[k] = _split_list(v, microbatch_size * self.num_choices) - elif k == 'choice_groupings': - # List of list, so we have to directly call _split_list - chunked[k] = _split_list(v, microbatch_size) - elif k in real: + elif type(v) == list: + # list of tensors - 'continuation_indices' + if k in self.list_of_tensors_keys: + chunked[k] = _split_list(v, microbatch_size * self.num_choices) + # list of tuples - 'choice_groupings' + elif k in self.list_of_tuples_keys: + chunked[k] = _split_list(v, microbatch_size) + # list - 'gold_indices' + elif k in self.list_of_primitives: + chunked[k] = _default_split_batch(v, microbatch_size) + else: + raise ValueError(f'Unexpected key {k} in list splitting') + elif k in self.tensor_keys: chunked[k] = _default_split_batch(v, microbatch_size * self.num_choices) - elif k in logical: - chunked[k] = _default_split_batch(v, microbatch_size) else: - raise ValueError(f'Unexpected key {k}') + raise ValueError(f'Unexpected key {k} in batch splitting') num_chunks = len(chunked['input_ids']) # Broadcast primitives to all chunks for k, v in batch.items(): - if isinstance(v, (int, float, str, bool)): + if k in self.static_keys: chunked[k] = [v] * num_chunks + return [{k: v[idx] for k, v in chunked.items()} for idx in range(num_chunks)] class InContextLearningSchemaTaskDataset(InContextLearningMultipleChoiceTaskDataset): - """A dataset that constructs batches for in-context learning schema evaluation + """A dataset that constructs batches for in-context learning schema evaluation. A schema task involves sentences with a fill-in-the-blank where the user needs to choose the correct word to fill in from a set of N options. We use the partial evaluation technique from https://arxiv.org/abs/1806.02847 to determine the model's choice of fill-in word. - Each batch then consists of batch_size // N distinct tasks and has the following the structure - 'input_ids': Input tensor batch x seqlen x # tokens - 'continuation_indices': List of |batch| consisting of tensors indicating which indices in the sequence correspond to the question answer (aka continuation) - 'mode': Indicates to the model that this is an ICL task and may rely on a custom code path to properly update metrics - 'labels': Identical to the input, used by the model to calculate loss/metrics - 'gold_indices': List of length |batch_size // N| indicating for each question, which of the answers is correct (via an integer [0, N-1]) - 'choice_groupings': Indicates which indices of the batch correspond to which questions - Args: - dataset_uri (str): Either a local path, or a remote path beginning with ``s3://``, or another backend - supported by :meth:`composer.utils.maybe_create_object_store_from_uri`. Dataset must consist of rows of JSON data points with "query", - "choices", and "gold" index. See tests/datasets/local_data/piqa_small.jsonl. 
- tokenizer (Union[transformers.PreTrainedTokenizer, transformers.PreTrainedTokenizerFast]): The tokenizer used to transform data into batches - batch_size (int): Size of a batch used for eval - max_seq_len (int): The sequence length expected by the model - pad_tok_id (int): The special token reserved for padding the ends of batches - num_fewshot (int): The number of complete fewshot examples to prepend before each test example - prompt_string (str): Prompt string to put once before all fewshot examples/test examples (e.g. 'translate english to french') - example_delimiter (str): Separator that goes between individual (context, continuation) pairs (e.g. '\n') - continuation_delimiter: (str): Separator that goes between context and continuation in each example (e.g. '->') - destination_path (str): Temporary path to store downloaded datasets - fewshot_random_seed (int): Random seed used to select fewshot examples + + The default input format is a jsonl file with the following fields: + - context_options: List of strings corresponding to possible preceding context options for the continuation + - gold: Index of the correct context from 'context_options' + - continuation: The finishing continuation + + Each batch then consists of ``batch_size // N`` distinct tasks and has the following the structure + - input_ids: Input tensor ``batch x seqlen x # of tokens`` + - continuation_indices: List of ``batch`` consisting of tensors indicating which indices in the sequence correspond to the question answer (aka continuation) + - mode: Indicates to the model that this is an ICL task and may rely on a custom code path to properly update metrics + - labels: Identical to the input, used by the model to calculate loss/metrics + - gold_indices: List of length ``batch_size // N`` indicating for each question, which of the answers is correct (via an integer [0, N-1]) + - choice_groupings: Indicates which indices of the batch correspond to which questions + """ - def __init__( - self, - dataset_uri: str, - tokenizer: Union[transformers.PreTrainedTokenizer, transformers.PreTrainedTokenizerFast], - max_seq_len: int, - pad_tok_id: int, - num_fewshot: int, - prompt_string: str, - example_delimiter: str, - continuation_delimiter: str, - destination_path: str, - fewshot_random_seed: int, - ): - try: - from datasets import load_dataset # pyright: ignore [reportGeneralTypeIssues] - except ImportError as e: - raise MissingConditionalImportError(extra_deps_group='nlp', - conda_package='datasets', - conda_channel='conda-forge') from e + def __init__(self, choices_key='context_options', *args, **kwargs): + static_keys = ['mode'] + tensor_keys = ['input_ids', 'labels', 'attention_mask'] + list_of_tensors_keys = ['continuation_indices'] + super().__init__(choices_key=choices_key, + context_key=choices_key, + static_keys=static_keys, + tensor_keys=tensor_keys, + list_of_tensors_keys=list_of_tensors_keys, + *args, + **kwargs) + self.base_batch = { + 'input_ids': [], + 'continuation_indices': [], + 'mode': 'icl_task', + 'labels': [], + 'gold_indices': [], + 'choice_groupings': [], + } - with dist.local_rank_zero_download_and_wait(destination_path): - if dist.get_local_rank() == 0: - get_file(dataset_uri, destination_path, overwrite=True) - dataset = load_dataset('json', data_files=destination_path, split='train', streaming=False) - self.samples = list( - dataset.map( - lambda examples: { - 'context_options': examples['context_options'], - 'continuation': examples['continuation'], - 'gold': examples['gold'] - })) - self.samples 
= strip_data(self.samples) - - self.num_choices = len(self.samples[0]['context_options']) - self.tokenizer = tokenizer - self.max_seq_len = max_seq_len - self.pad_tok_id = pad_tok_id - fewshot_rng = random.Random(fewshot_random_seed) + def construct_context(self, example, preceding_text: str = '', add_answer: bool = False) -> str: + """ + Takes a example and constructs a context with the correct context for the example's continuation. - self.prefix_space = _tokenizer_needs_prefix_space(self.tokenizer) + Args: + example (Dict): The example from which to construct the context + preceding_text (str): Any preceding text, needed to if self.example_delimiter is needed at the beginning + add_answer (bool): This will always be true when calling this function for SchemaTaskDataset + + Returns: + str: The single correct context for a given continuation + """ + context_options = example[self.choices_key] + gold_idx = example['gold'] + continuation = example['continuation'] + context = context_options[gold_idx] + if len(preceding_text) > 0: + context = f'{self.example_delimiter}{context}' + context = f'{context}{self.continuation_delimiter}{continuation}' + return context + + def _construct_multiple_contexts(self, example: Dict, preceding_text: str = '') -> List[str]: + """ + Takes a example and constructs all contexts. Optionally, appends this to preceeding text (such as a + prompt or fewshot examples). - self.encoded_dataset = self.prep_examples(num_fewshot, prompt_string, example_delimiter, continuation_delimiter, - fewshot_rng) - - def prep_examples(self, num_fewshot: int, prompt_string: str, example_delimiter: str, continuation_delimiter: str, - fewshot_rng: random.Random): - """Prepares a set of schema questions into tokenized format with prompt and few shot examples. - Each question consists of a set of possible contexts followed by a continuation, only one of the contexts would logically permit the continuation. - At inference time we construct individual inference examples consisting of a single context option + the continuation, - as well as an optional (prompt) and optional list of example correct context option + continuations, which precede the test context option + continuation. - For schema, this method provides information relaying which of the answer choices is the correct one. This - information is used for computing accuracy metrics. Args: - num_fewshot (int): Number of examples context/continuation pairs to prepend to the test pair - prompt_string (str): The prompt to prepend to all inputs - example_delimiter (str): The delimiter used to separate each example query/answer pair - continuation_delimiter (str): The delimiter used to separate each query from its answer - fewshot_rng (random.Random): Random number generator used to select fewshot examples + example (Dict): The example from which to construct the context + preceding_text (str): Any preceding text, needed to if self.example_delimiter is needed at the beginning + Returns: - dict: Contains the query, the list of encoded potential answer choices, the preamble (prompt + fewshot examples), and - the index of the correct answer choice. 
+ list: All context options for the selected example with formatting """ + context_options = example[self.choices_key] + if len(preceding_text) > 0: + if self.strip_data: + cont_del = self.continuation_delimiter.rstrip() + else: + cont_del = self.continuation_delimiter + context_options = [f'{self.example_delimiter}{c}{cont_del}' for c in context_options] + return context_options - examples = [] - for sample_idx in tqdm(range(len(self.samples))): - - preamble = prompt_string - if num_fewshot > 0: - fewshot_idxs = _get_fewshot_sample_idxs(len(self.samples), num_fewshot, sample_idx, fewshot_rng) - for fewshot_idx in fewshot_idxs: - context_options, continuation, gold_idx = self.samples[fewshot_idx][ - 'context_options'], self.samples[fewshot_idx]['continuation'], self.samples[fewshot_idx]['gold'] - assert isinstance(gold_idx, int) - context = context_options[gold_idx] - if len(preamble) > 0: - context = f'{example_delimiter}{context}' - preamble += f'{context}{continuation_delimiter}{continuation}' - - encoded_example = {} - context_options, continuation, gold_idx = self.samples[sample_idx]['context_options'], self.samples[ - sample_idx]['continuation'], self.samples[sample_idx]['gold'], - - # rstrip the continuation delimiter, because the prompt ending in a space results in degenerate output - continuation_delimiter_stripped = continuation_delimiter.rstrip() - - if len(preamble) > 0: - context_options = [f'{example_delimiter}{c}{continuation_delimiter_stripped}' for c in context_options] - encoded_example['preamble'] = self.tokenizer( - preamble - ) # if the preamble is empty then these will be 0-length lists, unless the tokenizer adds special tokens to empty strings (e.g. OPT tokenizer) - example_ids = encoded_example['preamble']['input_ids'] - if (self.tokenizer.eos_token_id is not None and - len(example_ids) > 1 and # pyright: ignore[reportGeneralTypeIssues] - example_ids[-1] == self.tokenizer.eos_token_id): # pyright: ignore[reportGeneralTypeIssues] - encoded_example['preamble']['input_ids'] = example_ids[:-1] # pyright: ignore[reportGeneralTypeIssues] - - encoded_example['gold_idx'] = gold_idx - encoded_example['context_options'] = [self.tokenizer(c, add_special_tokens=False) for c in context_options] + def _prep_example( + self, + example: Dict, + example_idx: int, + num_fewshot: int, + prompt_string: str, + fewshot_rng: random.Random, + ) -> Dict[str, Any]: + """ + Prepares a single example from a HF Dataset into tokenized format with prompt and fewshot examples. - if self.prefix_space: - continuation = f' {continuation}' if not continuation.startswith(' ') else continuation - encoded_example['continuation'] = self.tokenizer(continuation, add_special_tokens=False) - examples.append(encoded_example) + Each task consists of multiple contexts and a single, correct continuation. Will preprend fewshot examples and + prompt if present. 
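# --- Illustrative sketch of the schema task data flow (all text and delimiters are assumed). ---
schema_row = {
    'context_options': ['Because it was raining,', 'Because it was sunny,'],
    'gold': 0,
    'continuation': ' she took an umbrella.',
}
example_delimiter, continuation_delimiter = '\n', ' '

# construct_context (used for fewshot examples) keeps only the correct context option and
# appends the continuation after the continuation delimiter:
fewshot_context = (f"{example_delimiter}{schema_row['context_options'][schema_row['gold']]}"
                   f"{continuation_delimiter}{schema_row['continuation']}")

# _construct_multiple_contexts (used for the test example) keeps every option; when there is
# preceding text, each option gets the example delimiter prepended and the continuation
# delimiter (rstripped when strip_data is on) appended. All options are later paired with the
# same tokenized continuation:
all_contexts = [f'{example_delimiter}{c}{continuation_delimiter.rstrip()}'
                for c in schema_row['context_options']]
assert len(all_contexts) == len(schema_row['context_options'])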
- return examples + Args: + example (Dict): A dictionary from the hf dataset + example_idx (int): The index of example + num_fewshot (int): Number of examples context/continuation pairs to prepend to the test pair + prompt_string (str): The prompt to prepend to all inputs + fewshot_rng (random.Random): Random number generator to use for fewshot sampling - def collate_fn(self, data): - inputs = [] - continuation_indices = [] - gold_idxs = [] - choice_groupings = [] - for data_pair in data: + Returns: + Dict: Contains a dictionary with the tokenized data + """ + prompt_and_fewshot = self._generate_few_shot_prompt(num_fewshot, example_idx, prompt_string, fewshot_rng) + ctxt = self._construct_multiple_contexts(example, prompt_and_fewshot) + tokenized_example = self.tokenize_example(prompt_and_fewshot, ctxt, example) + return tokenized_example - continuation_start_idx = len(continuation_indices) - preamble, context_options, continuation, gold_idx = (data_pair['preamble'], data_pair['context_options'], - data_pair['continuation'], data_pair['gold_idx']) - - for ctxt in context_options: - context_enc = preamble['input_ids'] + ctxt['input_ids'] - continuation_enc = continuation['input_ids'] - inp, continuation_span = _make_padded_input(context_enc, continuation_enc, self.max_seq_len, - self.pad_tok_id) - - inputs.append(inp) - continuation_indices.append(continuation_span) - - gold_idxs.append(gold_idx) - continuation_end_idx = len(continuation_indices) - choice_groupings.append((continuation_start_idx, continuation_end_idx)) - - # We run each distinct query + answer choice through the model separately and determine which - # answer has the lowest per-token-perplexity. - # - # If each question has N possible choices, all N must be grouped together as distinct elements of the batch - # since the batch may consist of multiple questions, the choice_groupings indicates - # which contiguous sequences of elements in the batch correspond to which question - # gold_indices indicates which of the [0, N-1] choices is the correct one for each question. - batch = { - 'input_ids': torch.stack(inputs), - 'continuation_indices': continuation_indices, - 'mode': 'icl_task', - 'labels': torch.stack(inputs), - 'gold_indices': gold_idxs, - 'choice_groupings': choice_groupings - } - batch['attention_mask'] = ~(batch['input_ids'] == self.pad_tok_id) - return batch + def tokenize_example(self, prompt_and_fewshot: str, context_options: List[str], example: Dict) -> Dict[str, Any]: + """ + Runs text through the tokenizer and handle special cases. + Args: + prompt_and_fewshot (str): The collection of the prompt and fewshot examples that belongs before the example's context + ctx (str): The specific example's derrived context + example (Dict): The example as a dictionary. 
-class InContextLearningCodeEvalDataset(Dataset): - """ A dataset that constructs batches for in-context learning code evaluation + Returns: + Dict: Dictionary with the tokenized data + """ + tokenized_example = {} + preamble = self.tokenizer(prompt_and_fewshot)['input_ids'] + assert isinstance(preamble, list) + preamble = self._fix_eos_on_preamble(preamble) + encoded_contexts = [ + preamble + # pyright: ignore[reportOperatorIssue, reportGeneralTypeIssues] + self.tokenizer(c, add_special_tokens=False)['input_ids'] # pyright: ignore[reportOperatorIssue, ] + for c in context_options + ] + continuation = example['continuation'] + if self.prefix_space: + continuation = (f' {continuation}' if not continuation.startswith(' ') else continuation) + tokenized_continuation = self.tokenizer(continuation, add_special_tokens=False)['input_ids'] + + tokenized_example[self.context_key] = [] + tokenized_example['continuation_indices'] = [] + tokenized_example[self.answer_key] = [] + for context in encoded_contexts: + assert isinstance(context, list) + assert isinstance(tokenized_continuation, list) + trimmed_context = _trim_context(context, tokenized_continuation, self.padding_size) + assert isinstance(trimmed_context, list) + continuation_indices = _get_continuation_span(trimmed_context, tokenized_continuation) + padded_context = _make_padded_input(trimmed_context, tokenized_continuation, self.padding_size, + self.pad_tok_id, self.padding_side) + tokenized_example[self.context_key].append(padded_context) + tokenized_example['continuation_indices'].append(continuation_indices) + tokenized_example[self.answer_key].append(tokenized_continuation) + + tokenized_example['gold'] = example['gold'] + return tokenized_example + + +class InContextLearningCodeEvalDataset(InContextLearningDataset): + """ + A dataset that constructs batches for in-context learning code evaluation. The input format is expected to be a jsonl file with the following fields: - - task_id: label of given task - - prompt: the code snippet that must be completed - - entry_point: the entry to the function/code snippet to generate - - canonical_solution: working solution - - test: the checker code that will run to completion if the code generation is valid and otherwise throw assertion - - test_inputs: list of test inputs - - test_outputs: list of test outputs - - language: the language of the code snippet - Args: - dataset_uri (str): Either a local path, or a remote path beginning with ``s3://``, or another backend - supported by :meth:`composer.utils.maybe_create_object_store_from_uri`. Dataset must consist of rows of JSON data points with "task_id", - "prompt", "entry_point", "canonical_solution", "test", "test_inputs", and "test_outputs". See tests/datasets/local_data/human_eval_small.jsonl. - tokenizer (Union[transformers.PreTrainedTokenizer, transformers.PreTrainedTokenizerFast]): The tokenizer used to map between strings and token ids - batch_size (int): Size of a batch used for eval - max_seq_len (int): The maximum sequence length supported by the model - pad_tok_id (int): The special token reserved for padding batches - num_fewshot (int): The number of complete fewshot examples to prepend before each test example - prompt_string (str): Prompt string to put once before all fewshot examples/test examples (e.g. 'translate english to french') - example_delimiter (str): Separator that goes between individual (context, answer) pairs (e.g. 
'\n') - destination_path (str): Temporary path to store downloaded datasets - code_prelimiter (str): String to put before each code prompt (e.g. 'Q: ') - fewshot_random_seed (int): Random seed to use for fewshot sampling - generations_per_sample: how many outputs to generate per prompt - top_p: top_p sampling parameter for nucleus sampling - top_k: top_k sampling parameter for number of samples to consider + + - task_id: Label of given task + - prompt: The code snippet that must be completed + - entry_point: The entry to the function/code snippet to generate + - canonical_solution: Working solution + - test: The checker code that will run to completion if the code generation is valid and otherwise throw an assertion error + - test_inputs: List of test inputs + - test_outputs: List of test outputs + - language: The language of the code snippet + + Each batch then consists of the following structure + + - input_ids: Input tensor batch x seqlen x num tokens + - mode: Always set to 'generate'. Indicates to the model that this is an ICL task and may rely on a custom code path to properly update metrics + - labels: Exact solution for the coding problem + - prompts: Prompt for the task + - entry_points: List of entry points + - test_inputs: List of test inputs + - test_outputs: List of test outputs + - languages: List of languages + - pass_at_k: Passed value for pass_at_k + - generation_length: Derived maximum generation length + - generation_kwargs: Dictionary of kwargs needed for generation. Includes the following, which will be individually overwritten + by keys in generation_kwargs if set (see https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationConfig + for more details): + + - pad_token_id: ID for padding token, derived automatically + - num_beams: How many beams to search for generations, set to 1 + - num_return_sequences: Value passed for 'generations_per_sample', how many generations per prompt + - do_sample: Determines whether model is sampling or greedily decoding. Always set to True + - use_cache: Whether or not to use past key values to speed up sampling.
Always set to True + + Additional Args: + generations_per_sample (int) (defaults to 1): The number of independently computed returned sequences for each element in the batch + pass_at_k (int) (defaults to 1): k for how many chances the model gets to write passing code """ def __init__( self, - dataset_uri: str, - tokenizer: Union[transformers.PreTrainedTokenizer, transformers.PreTrainedTokenizerFast], - max_seq_len: int, - pad_tok_id: int, - num_fewshot: int, - prompt_string: str, - example_delimiter: str, - destination_path: str, - code_prelimiter: str, - fewshot_random_seed: int, generations_per_sample: int, pass_at_k: int = 1, - top_p: Optional[float] = 0.95, - top_k: Optional[int] = 40, + *args, + **kwargs, ): - if tokenizer.eos_token_id is None: - raise ValueError('`InContextLearningCodeEvalDataset` tokenizer must have non-null `eos_token_id`') - try: - from datasets import load_dataset # pyright: ignore [reportGeneralTypeIssues] - except ImportError as e: - raise MissingConditionalImportError(extra_deps_group='nlp', - conda_package='datasets', - conda_channel='conda-forge') from e - with dist.local_rank_zero_download_and_wait(destination_path): - if dist.get_local_rank() == 0: - get_file(dataset_uri, destination_path, overwrite=True) - dataset = load_dataset('json', data_files=destination_path, split='train', streaming=False) - self.samples = list( - dataset.map( - lambda examples: { - 'task_id': examples['task_id'], - 'prompt': examples['prompt'], - 'canonical_solution': examples['canonical_solution'], - 'test': examples['test'], - 'entry_point': examples['entry_point'], - 'test_inputs': examples['test_inputs'], - 'test_outputs': examples['test_outputs'], - 'language': examples['language'], - })) - if generations_per_sample < pass_at_k: raise ValueError( f'generations_per_sample ({generations_per_sample}) must be greater than or equal to pass_at_k ({pass_at_k}) for code evaluation.' ) - - self.pass_at_k = pass_at_k - self.generations_per_sample = generations_per_sample - - self.tokenizer = tokenizer - self.max_seq_len = max_seq_len - self.pad_tok_id = pad_tok_id - self.padding_side = 'left' + batch_mapping = { + 'input_ids': 'prompt', + 'prompts': 'prompt_text', + 'tests': 'test', + 'labels': 'canonical_solution', + 'entry_points': 'entry_point', + 'test_inputs': 'test_inputs', + 'test_outputs': 'test_outputs', + 'languages': 'language' + } + # Linting complains if these are not set in init self.max_prompt_length = 0 - self.top_p = top_p - self.top_k = top_k self.max_answer_length = 0 - fewshot_rng = random.Random(fewshot_random_seed) - self.encoded_dataset = self.prep_examples(num_fewshot, prompt_string, example_delimiter, code_prelimiter, - fewshot_rng) - - def prep_examples(self, num_fewshot: int, prompt_string: str, example_delimiter: str, code_prelimiter: str, - fewshot_rng: random.Random): - """Prepares a set of code evaluation tasks into tokenized format with prompt and fewshot examples. - - Each task consists of a context as well as an optional prompt and optional list of - example context/continuation pairs which precede the test context/continuation pair. 
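# --- Illustrative jsonl row for the code evaluation task (made-up problem; field names from
# --- the docstring above). ---
import json

code_row = {
    'task_id': 'Demo/0',
    'prompt': 'def add(a, b):\n    """Return the sum of a and b."""\n',
    'entry_point': 'add',
    'canonical_solution': '    return a + b\n',
    'test': 'def check(candidate):\n    assert candidate(1, 2) == 3\n',
    'test_inputs': ['(1, 2)'],
    'test_outputs': ['3'],
    'language': 'python',
}
with open('/tmp/code_eval_example.jsonl', 'w') as f:
    f.write(json.dumps(code_row) + '\n')

# Per the batch fields above, prompts are left-padded to the longest prompt in the dataset and
# the generation budget is capped at min(max_answer_length, max_seq_len - max_prompt_length).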
- - Args: - num_fewshot (int): Number of examples context/continuation pairs to prepend to the test pair - prompt_string (str): The prompt to prepend to all inputs - example_delimiter (str): The delimiter used to separate each individual context/continuation pair - code_prelimiter (str): The text to prepend to each code prompt - fewshot_rng (random.Random): Random number generator to use for fewshot sampling - - Returns: - dict: Contains the context, the continuation, and the preamble (prompt + fewshot examples) - """ - max_prompt_length = 0 - examples = [] - max_answer_length = 0 - for sample_idx in tqdm(range(len(self.samples))): - encoded_example = {} - - preamble = prompt_string - - if num_fewshot > 0: - fewshot_idxs = _get_fewshot_sample_idxs(len(self.samples), num_fewshot, sample_idx, fewshot_rng) - for fewshot_idx in fewshot_idxs: - ctxt, cont = self.samples[fewshot_idx]['prompt'], self.samples[fewshot_idx]['canonical_solution'] - ctxt = f'{code_prelimiter}{ctxt}' - if len(preamble) > 0: - ctxt = f'{example_delimiter}{ctxt}' - preamble += f'{ctxt}{cont}' - - ctxt = self.samples[sample_idx]['prompt'] - ctxt = f'{code_prelimiter}{ctxt}' - if len(preamble) > 0: - ctxt = f'{example_delimiter}{ctxt}' - - # If the preamble is empty then this will be a 0-length list, unless the tokenizer adds special tokens to empty strings (e.g. OPT tokenizer) - encoded_example['preamble'] = self.tokenizer(preamble) - # If there is an EOS token added, we need to remove it so it is not in the middle of the prompt - example_ids = encoded_example['preamble']['input_ids'] - if (self.tokenizer.eos_token_id is not None and - len(example_ids) > 1 and # pyright: ignore[reportGeneralTypeIssues] - example_ids[-1] == self.tokenizer.eos_token_id): # pyright: ignore[reportGeneralTypeIssues] - encoded_example['preamble']['input_ids'] = example_ids[:-1] # pyright: ignore[reportGeneralTypeIssues] - - encoded_example['prompt'] = self.tokenizer(ctxt, add_special_tokens=False) - encoded_example['prompt_text'] = self.samples[sample_idx]['prompt'] - encoded_example['task_id'] = self.samples[sample_idx]['task_id'] - encoded_example['canonical_solution'] = self.samples[sample_idx]['canonical_solution'] - encoded_example['test'] = self.samples[sample_idx]['test'] - encoded_example['entry_point'] = self.samples[sample_idx]['entry_point'] - encoded_example['test_inputs'] = self.samples[sample_idx]['test_inputs'] - encoded_example['test_outputs'] = self.samples[sample_idx]['test_outputs'] - encoded_example['language'] = self.samples[sample_idx]['language'] - - examples.append(encoded_example) - max_prompt_length = max( - max_prompt_length, - len(encoded_example['preamble']['input_ids'] + - encoded_example['prompt']['input_ids'])) # pyright: ignore[reportGeneralTypeIssues] - max_answer_length = max( - max_answer_length, - len(self.tokenizer(encoded_example['canonical_solution'], - add_special_tokens=False)['input_ids'])) # pyright: ignore[reportGeneralTypeIssues] - - self.max_prompt_length = max_prompt_length - self.max_answer_length = max_answer_length + _MAX_ANSWER_BUFFER_LENGTH - return examples - - def __getitem__(self, index): - return self.encoded_dataset[index] - - def __len__(self): - return len(self.encoded_dataset) - - def collate_fn(self, data): - inputs, prompts, tests, canonical_solutions, entry_points, test_inputs, test_outputs, languages = [], [], [], [], [], [], [], [] - for sample in data: - preamble, prompt, text_prompt, canonical_solution, test, entry_point, test_input, test_output, language = ( - 
sample['preamble'], - sample['prompt'], - sample['prompt_text'], - sample['canonical_solution'], - sample['test'], - sample['entry_point'], - sample['test_inputs'], - sample['test_outputs'], - sample['language'], - ) - context_enc = preamble['input_ids'] + prompt['input_ids'] - inp, _ = _make_padded_input(context_enc, [], - self.max_prompt_length, - self.pad_tok_id, - padding_side=self.padding_side) - - inputs.append(inp) - tests.append(test) - prompts.append(text_prompt) - canonical_solutions.append(canonical_solution) - entry_points.append(entry_point) - test_inputs.append(test_input) - test_outputs.append(test_output) - languages.append(language) - - batch = { - 'input_ids': torch.stack(inputs), + static_keys = ['mode', 'pass_at_k', 'generation_length', 'generation_kwargs'] + list_keys = ['prompts', 'tests', 'entry_points', 'test_inputs', 'test_outputs', 'languages', 'labels'] + tensor_keys = ['input_ids', 'attention_mask'] + super().__init__( + context_key='prompt', + answer_key='canonical_solution', + strip_dataset=False, + static_keys=static_keys, + list_keys=list_keys, + tensor_keys=tensor_keys, + tokenize_labels=False, + padding_side='left', + batch_mapping=batch_mapping, + *args, + **kwargs, + ) + self._set_max_prompt_and_answer_lengths() + self.dataset = self.dataset.map(self._trim_padding) + self.base_batch = { + 'input_ids': [], 'mode': 'generate', - 'labels': canonical_solutions, - 'prompts': prompts, # list of prompts - 'tests': tests, # list of tests - 'canonical_solutions': canonical_solutions, # list of solutions - 'entry_points': entry_points, # list of entry points - 'test_inputs': test_inputs, # list of test inputs - 'test_outputs': test_outputs, # list of test outputs - 'languages': languages, # list of languages - 'pass_at_k': self.pass_at_k, + 'labels': [], + 'prompts': [], + 'tests': [], + 'entry_points': [], + 'test_inputs': [], + 'test_outputs': [], + 'languages': [], + 'pass_at_k': pass_at_k, 'generation_length': min(self.max_answer_length, self.max_seq_len - self.max_prompt_length), 'generation_kwargs': { 'pad_token_id': self.pad_tok_id, 'num_beams': 1, # single beam - 'num_return_sequences': self.generations_per_sample, # how many gens per prompt + 'num_return_sequences': generations_per_sample, 'do_sample': True, - 'top_p': self.top_p, - 'top_k': self.top_k, 'use_cache': True, 'eos_token_id': self.tokenizer.eos_token_id } } - batch['attention_mask'] = ~(batch['input_ids'] == self.pad_tok_id) - return batch + self.update_generation_kwargs(kwargs.get('generation_kwargs', {})) - def get_num_samples_in_batch(self, batch) -> int: - # Count number of inputs in the batch - return batch['input_ids'].shape[0] + def _set_max_prompt_and_answer_lengths(self): + """ + Iterates through the dataset and finds the maximum prompt length and sequence lengths - def split_batch(self, batch: Any, microbatch_size: int): - # Don't split kwargs that don't change - # Normally split torch tensors - # List split lists of strings - no_split = ['mode', 'generation_length', 'pass_at_k', 'generation_kwargs'] - normal_split = ['input_ids', 'attention_mask'] - list_split = [ - 'labels', 'tests', 'canonical_solutions', 'entry_points', 'test_inputs', 'test_outputs', 'prompts', - 'languages' - ] - chunked = {} - for k, v in batch.items(): - if k in no_split: - # Defer broadcasting until we know num_chunks - pass - elif k in list_split: - chunked[k] = _split_list(v, microbatch_size) - elif k in normal_split: - chunked[k] = _default_split_batch(v, microbatch_size) - else: - raise 
ValueError(f'Unexpected key {k}') - num_chunks = len(chunked['input_ids']) - for k, v in batch.items(): - if isinstance(v, (int, float, str, bool, dict)): - chunked[k] = [v] * num_chunks + Returns: + None + """ + max_prompt_length = 0 + max_answer_length = 0 + for example in self.dataset: + assert isinstance(example, Dict) + unpadded_example = [token for token in example[self.context_key] if token != self.pad_tok_id] + max_prompt_length = max(max_prompt_length, len(unpadded_example)) - return [{k: v[idx] for k, v in chunked.items()} for idx in range(num_chunks)] + tokenized_answer = self.tokenizer(example['canonical_solution'], add_special_tokens=False)['input_ids'] + assert isinstance(tokenized_answer, list) + len_tokenized_answer = len(tokenized_answer) + max_answer_length = max(max_answer_length, len_tokenized_answer) + + self.max_prompt_length = max_prompt_length + self.max_answer_length = max_answer_length + _MAX_ANSWER_BUFFER_LENGTH + + def _trim_padding(self, example: Dict): + """ + Adjusts padding to the maximum prompt length rather than max_seq_len. + Needs to be done after the dataset has been processed because we don't know the maximum + prompt length until after we've tokenized it. + + Returns: + dataset: A HuggingFace Dataset with different padding lengths for example[self.context_key] + """ + # Remove padding tokens applied during tokenization + unpadded_prompt = [token for token in example[self.context_key] if token != self.pad_tok_id] + # Reapply padding only to max_prompt_length + full_prompt = _trim_context(unpadded_prompt, [], self.max_prompt_length) + padded_context = _make_padded_input(full_prompt, [], self.max_prompt_length, self.pad_tok_id, self.padding_side) + + example[self.context_key] = padded_context + return example + + def tokenize_example(self, prompt_and_fewshot: str, ctxt: str, example: Dict) -> Dict[str, Any]: + """ + Adds extra code task details to the example dictionary. + See InContextLearningDataset for more details + """ + tokenized_example = super().tokenize_example(prompt_and_fewshot, ctxt, example) + tokenized_example['prompt_text'] = example['prompt'] + tokenized_example['task_id'] = example['task_id'] + tokenized_example['canonical_solution'] = example['canonical_solution'] + tokenized_example['test'] = example['test'] + tokenized_example['entry_point'] = example['entry_point'] + tokenized_example['test_inputs'] = example['test_inputs'] + tokenized_example['test_outputs'] = example['test_outputs'] + tokenized_example['language'] = example['language'] + return tokenized_example def build_icl_dataloader( icl_task_type: str, dataset_uri: str, - tokenizer: Union[transformers.PreTrainedTokenizer, transformers.PreTrainedTokenizerFast], + tokenizer: transformers.PreTrainedTokenizerBase, batch_size: int, max_seq_len: int, pad_tok_id: int, @@ -1193,81 +1362,120 @@ def build_icl_dataloader( prompt_string: str, # e.g. 'translate english to french:' example_delimiter: str, # e.g. '\n' continuation_delimiter: str, # e.g. '' + hf_loading_vars: Dict, + hf_parsing_map: Dict, destination_path: str, - question_prelimiter: str = '', # e.g. 'Question: ' - cot_delimiter: str = '', - fewshot_random_seed: int = 1234, - pass_at_k: int = 1, - generations_per_sample: int = 1, + prelimiter: str, # e.g. 'Question: ' + cot_delimiter: str, # e.g. 
' ### ' + fewshot_random_seed: int, + pass_at_k: int, + generations_per_sample: int, + generation_kwargs: Dict, early_stopping_criteria: Optional[List[str]] = None, do_normalization: bool = True) -> DataSpec: + """ + Factory method that builds the specific dataset for the specified icl_task_type. + See documentation for `get_icl_task_dataloader` for arugment documentation. + + When writing a dataset for a new task, here you will need to: + 1. add the dataset to the factory and choose an appropriate string + 2. set the batch size for that task (see InContextLearningMultipleChoiceTaskDataset for why + this might be different) + 3. set the `split_batch` funciton if necessary + """ if icl_task_type == 'multiple_choice': - dataset = InContextLearningMultipleChoiceTaskDataset(dataset_uri, - tokenizer, - max_seq_len, - pad_tok_id, - num_fewshot, - prompt_string, - example_delimiter, - continuation_delimiter, - destination_path=destination_path, - fewshot_random_seed=fewshot_random_seed) + dataset = InContextLearningMultipleChoiceTaskDataset( + dataset_uri=dataset_uri, + tokenizer=tokenizer, + max_seq_len=max_seq_len, + pad_tok_id=pad_tok_id, + num_fewshot=num_fewshot, + prompt_string=prompt_string, + example_delimiter=example_delimiter, + continuation_delimiter=continuation_delimiter, + destination_path=destination_path, + fewshot_random_seed=fewshot_random_seed, + hf_loading_vars=hf_loading_vars, + hf_parsing_map=hf_parsing_map, + generation_kwargs=generation_kwargs, + ) batch_size = max(dataset.num_choices, batch_size) effective_batchsize = batch_size // dataset.num_choices elif icl_task_type == 'schema': - dataset = InContextLearningSchemaTaskDataset(dataset_uri, - tokenizer, - max_seq_len, - pad_tok_id, - num_fewshot, - prompt_string, - example_delimiter, - continuation_delimiter, - destination_path=destination_path, - fewshot_random_seed=fewshot_random_seed) + dataset = InContextLearningSchemaTaskDataset( + dataset_uri=dataset_uri, + tokenizer=tokenizer, + max_seq_len=max_seq_len, + pad_tok_id=pad_tok_id, + num_fewshot=num_fewshot, + prompt_string=prompt_string, + example_delimiter=example_delimiter, + continuation_delimiter=continuation_delimiter, + destination_path=destination_path, + fewshot_random_seed=fewshot_random_seed, + hf_loading_vars=hf_loading_vars, + hf_parsing_map=hf_parsing_map, + generation_kwargs=generation_kwargs, + ) batch_size = max(dataset.num_choices, batch_size) effective_batchsize = batch_size // dataset.num_choices elif icl_task_type == 'language_modeling': - dataset = InContextLearningLMTaskDataset(dataset_uri, - tokenizer, - max_seq_len, - pad_tok_id, - num_fewshot, - prompt_string, - example_delimiter, - continuation_delimiter, - destination_path=destination_path, - fewshot_random_seed=fewshot_random_seed) + dataset = InContextLearningLMTaskDataset( + dataset_uri=dataset_uri, + tokenizer=tokenizer, + max_seq_len=max_seq_len, + pad_tok_id=pad_tok_id, + num_fewshot=num_fewshot, + prompt_string=prompt_string, + example_delimiter=example_delimiter, + continuation_delimiter=continuation_delimiter, + destination_path=destination_path, + fewshot_random_seed=fewshot_random_seed, + hf_loading_vars=hf_loading_vars, + hf_parsing_map=hf_parsing_map, + generation_kwargs=generation_kwargs, + ) effective_batchsize = batch_size elif icl_task_type == 'question_answering': - dataset = InContextLearningQATaskDataset(dataset_uri, - tokenizer, - max_seq_len, - pad_tok_id, - num_fewshot, - prompt_string, - example_delimiter, - continuation_delimiter, - 
destination_path=destination_path, - question_prelimiter=question_prelimiter, - fewshot_random_seed=fewshot_random_seed, - cot_delimiter=cot_delimiter, - early_stopping_criteria=early_stopping_criteria, - do_normalization=do_normalization) + dataset = InContextLearningQATaskDataset( + dataset_uri=dataset_uri, + tokenizer=tokenizer, + max_seq_len=max_seq_len, + pad_tok_id=pad_tok_id, + num_fewshot=num_fewshot, + prompt_string=prompt_string, + example_delimiter=example_delimiter, + continuation_delimiter=continuation_delimiter, + destination_path=destination_path, + prelimiter=prelimiter, + fewshot_random_seed=fewshot_random_seed, + hf_loading_vars=hf_loading_vars, + hf_parsing_map=hf_parsing_map, + cot_delimiter=cot_delimiter, + early_stopping_criteria=early_stopping_criteria, + do_normalization=do_normalization, + generation_kwargs=generation_kwargs, + ) effective_batchsize = batch_size elif icl_task_type == 'code_evaluation': - dataset = InContextLearningCodeEvalDataset(dataset_uri, - tokenizer, - max_seq_len, - pad_tok_id, - num_fewshot, - prompt_string, - example_delimiter, - destination_path=destination_path, - code_prelimiter=question_prelimiter, - fewshot_random_seed=fewshot_random_seed, - pass_at_k=pass_at_k, - generations_per_sample=generations_per_sample) + dataset = InContextLearningCodeEvalDataset( + dataset_uri=dataset_uri, + tokenizer=tokenizer, + max_seq_len=max_seq_len, + pad_tok_id=pad_tok_id, + num_fewshot=num_fewshot, + prompt_string=prompt_string, + example_delimiter=example_delimiter, + continuation_delimiter=continuation_delimiter, + destination_path=destination_path, + prelimiter=prelimiter, + fewshot_random_seed=fewshot_random_seed, + hf_loading_vars=hf_loading_vars, + hf_parsing_map=hf_parsing_map, + pass_at_k=pass_at_k, + generations_per_sample=generations_per_sample, + generation_kwargs=generation_kwargs, + ) effective_batchsize = batch_size else: raise Exception(f'Unrecognized ICL task type: {icl_task_type}') @@ -1277,7 +1485,12 @@ def build_icl_dataloader( split_batch = None if isinstance( dataset, - (InContextLearningMultipleChoiceTaskDataset, InContextLearningQATaskDataset, InContextLearningCodeEvalDataset)): + ( + InContextLearningMultipleChoiceTaskDataset, + InContextLearningQATaskDataset, + InContextLearningCodeEvalDataset, + ), + ): split_batch = dataset.split_batch return DataSpec( @@ -1293,8 +1506,10 @@ def build_icl_dataloader( ) -def partition_dataset_by_category(dataset_uri: str, destination_path: str) -> Dict[str, str]: - """If has_categories is enabled, we partition the dataset into a separate dataset for each category value in the data and write each partition to a local file. +def partition_dataset_by_category(dataset_uri: str, destination_path: str, hf_loading_vars: Dict, + hf_parsing_map: Dict) -> Dict[str, str]: + """ + If has_categories is enabled, we partition the dataset into a separate dataset for each category value in the data and write each partition to a local file. Args: dataset_uri (str): Location of dataset. @@ -1307,19 +1522,37 @@ def partition_dataset_by_category(dataset_uri: str, destination_path: str) -> Di Dict[str, str]: Mapping of category names to partitioned dataset local files names. 
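# --- Minimal sketch of the category partitioning described above (paths are hypothetical). ---
# Rows sharing a 'category' value are written to their own local jsonl file, and a mapping
# from category name to file path is returned so a separate dataloader can be built per category.
import json
from collections import defaultdict

rows = [
    {'context': '2 + 2 =', 'answer': '4', 'category': 'arithmetic'},
    {'context': 'The capital of France is', 'answer': 'Paris', 'category': 'geography'},
]
by_category = defaultdict(list)
for row in rows:
    by_category[row['category']].append(row)

output_files = {}
for category, subset in by_category.items():
    path = f'/tmp/icl_{category}.jsonl'
    with open(path, 'w', encoding='utf8') as f:
        for row in subset:
            f.write(json.dumps(row, ensure_ascii=False) + '\n')
    output_files[category] = path
# e.g. {'arithmetic': '/tmp/icl_arithmetic.jsonl', 'geography': '/tmp/icl_geography.jsonl'}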
""" try: - from datasets import load_dataset # pyright: ignore [reportGeneralTypeIssues] + from datasets import Dataset as HFDataset # pyright: ignore[reportGeneralTypeIssues] + from datasets import IterableDataset, load_dataset # pyright: ignore[reportGeneralTypeIssues] except ImportError as e: - raise MissingConditionalImportError(extra_deps_group='nlp', - conda_package='datasets', - conda_channel='conda-forge') from e - with dist.local_rank_zero_download_and_wait(destination_path): - if dist.get_local_rank() == 0: - get_file(dataset_uri, destination_path, overwrite=True) - dataset = load_dataset('json', data_files=destination_path, split='train', streaming=False) - if 'category' not in dataset.features.keys(): # type: ignore - raise Exception((f"Attempted to partition dataset by `category` but it doesn't have " - f'a `category` key. Got keys: {str(list(dataset.features.keys()))}')) # type: ignore - categories = sorted(set(dataset['category'])) # pyright: ignore[reportGeneralTypeIssues] + raise MissingConditionalImportError( + extra_deps_group='nlp', + conda_package='datasets', + conda_channel='conda-forge', + ) from e + if dataset_uri.startswith('hf://'): + dataset_uri = dataset_uri.replace('hf://', '') + dataset = load_dataset(dataset_uri, **hf_loading_vars) + assert isinstance(dataset, HFDataset) or isinstance(dataset, IterableDataset) + if hf_parsing_map: + dataset_parsing_func = lambda example: { + k: ' '.join([str(example[col]) for col in v]) for k, v in hf_parsing_map.items() + } + assert hasattr(dataset, 'column_names') + dataset = dataset.map(dataset_parsing_func, remove_columns=dataset.column_names) + else: + with dist.local_rank_zero_download_and_wait(destination_path): + if dist.get_local_rank() == 0: + get_file(dataset_uri, destination_path, overwrite=True) + dataset = load_dataset('json', data_files=destination_path, split='train', streaming=False) + assert isinstance(dataset, HFDataset) or isinstance(dataset, IterableDataset) + assert hasattr(dataset, 'features') + assert dataset.features is not None + if 'category' not in dataset.features.keys(): + raise Exception(f"""Attempted to partition dataset by `category` \ + but it doesn't have a `category` key. 
\ + Got keys: {str(list(dataset.features.keys()))}""") + categories = sorted(set(dataset['category'])) # pyright: ignore[reportIndexIssue, reportGeneralTypeIssues] output_files = {} for cat in categories: path = destination_path.split('/') @@ -1327,7 +1560,9 @@ def partition_dataset_by_category(dataset_uri: str, destination_path: str) -> Di tmp_path_to_broadcast = str(os.path.abspath(cat_dest)) gathered_paths = dist.all_gather_object(tmp_path_to_broadcast) if dist.get_local_rank() == 0: - subset = [l for l in dataset if l['category'] == cat] # pyright: ignore[reportGeneralTypeIssues] + subset = [ + l for l in dataset if l['category'] == cat # pyright: ignore[reportGeneralTypeIssues] + ] # pyright: ignore[reportArgumentType, reportCallIssue] with open(gathered_paths[0], 'w', encoding='utf8') as f: for l in subset: f.write(json.dumps(l, ensure_ascii=False) + '\n') @@ -1353,95 +1588,149 @@ def get_icl_task_dataloader( generations_per_sample: int = 1, cot_delimiter: str = '', has_categories: bool = False, + hf_loading_vars: Optional[Dict] = None, + hf_parsing_map: Optional[Dict] = None, + generation_kwargs: Optional[Dict] = None, early_stopping_criteria: Optional[List[str]] = None, do_normalization: bool = True) -> Union[DataSpec, Dict[str, DataSpec]]: """This constructs a dataloader (or dataloaders if has_categories is True) capable of evaluating LLMs on in-context learning language modeling tasks, for example LAMBADA. An example usage is below: - >>> dl = get_icl_task_dataloader( - ... 'language_modeling', - ... dataset_uri, - ... tokenizer, - ... batch_size=2, - ... max_seq_len=2048, - ... pad_tok_id=tokenizer.pad_token_id, - ... num_fewshot=10, - ... prompt_string='translate english to french', - ... example_delimiter='\n', - ... continuation_delimiter='' - ) - >>> eval_evaluator = Evaluator( - ... label="lambada", - ... dataloader=dl, - ... metric_names=['InContextLearningLMAccuracy'] - ... ) - >>> trainer = Trainer( - ... model=model, - ... train_dataloader=train_dataloader, - ... eval_dataloader=eval_evaluator, - ... optimizers=optimizer, - ... max_duration="1ep", - ... ) + .. testsetup:: + + import transformers + from composer.models import HuggingFaceModel + from composer.trainer import Trainer + dataset_uri = "/tmp/dataset_uri.jsonl" + dataset = RandomTextClassificationDataset(size=16, use_keys=True) + train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=8) + hf_model, tokenizer = HuggingFaceModel.hf_from_composer_checkpoint('composer-hf-checkpoint.pt') + # At this point, hf_model is randomly initialized + composer_model = HuggingFaceModel(hf_model, hf_tokenizer) + + Example: + + .. testcode:: + + + dl = get_icl_task_dataloader( + 'language_modeling', + dataset_uri, + tokenizer, + batch_size=2, + max_seq_len=2048, + pad_tok_id=tokenizer.pad_token_id, + num_fewshot=10, + prompt_string='translate english to french', + example_delimiter='\\n', + continuation_delimiter='' + ) + eval_evaluator = Evaluator( + label="lambada", + dataloader=dl, + metric_names=['InContextLearningLMAccuracy'] + ) + trainer = Trainer( + model=model, + train_dataloader=train_dataloader, + eval_dataloader=eval_evaluator, + optimizers=optimizer, + max_duration="1ep", + ) Args: - dataset_uri (str): Either a local path, or a remote path beginning with ``s3://``, or another backend - supported by :meth:`composer.utils.maybe_create_object_store_from_uri`. 
- tokenizer (Union[transformers.PreTrainedTokenizer, transformers.PreTrainedTokenizerFast]): The tokenizer used to transform data into batches + icl_task_type (str): Name of icl_task type. One of ['multiple_choice', 'schema', 'language_modeling', 'question_answering', 'code_evaluation'] + dataset_uri (str): A local path, a remote path beginning with ``s3://`` or another backend, or a HuggingFace dataset uri prepended with ``hf://``. + Alternate backends must be supported by :meth:`composer.utils.maybe_create_object_store_from_uri`. + A local dataset must consist of rows of JSON data points with task dependant fields. + The default keys expected are "context" and "answer". + tokenizer (transformers.PreTrainedTokenizerBase): The tokenizer used to map between strings and token ids. batch_size (int): Size of a batch used for eval - max_seq_len (int): The sequence length expected by the model - pad_tok_id (int): The special token reserved for padding the ends of batches - num_fewshot (int): The number of complete fewshot examples to pad each test example with - prompt_string (str): Prompt string to put once before all fewshot examples/test examples (e.g. 'translate english to french') - example_delimiter (str): Separator that goes between individual examples (e.g. '\n') - continuation_delimiter: (str): Separator that goes between context and continuation in each example (e.g. '->') - destination_path: (str): This is the local file where remote datasets will be saved. - question_prelimiter: (str): For QA tasks, this will be prepended to each question. + max_seq_len (int): The maximum sequence length supported by the model. + pad_tok_id (int): The special token used for padding batches. + num_fewshot (int): The number of complete fewshot examples to prepend before each test example. These are not identical across examples. + prompt_string (str, default = ''): Prompt string to put once before all fewshot examples/test examples (e.g. 'Translate english to french.'). + example_delimiter (str, default = '\\n'): Separator inserted before (context, answer) pairs (e.g. '\\n') for fewshot sampling and prompting. + continuation_delimiter: (str, default = ' '): Separator inserted between context and answer in each example (e.g. '\\nA: '). + destination_path: (str, default = ''): This is the local file where remote datasets will be saved. + question_prelimiter: (str, default = ''): Text to be prepended before each context, including few shot examples (e.g. "Question: "). + fewshot_random_seed (int, default = 1234): Random seed to use for fewshot sampling + pass_at_k (int): k for how many chances the model gets to write passing code. + generations_per_sample (int): How many outputs to generate per prompt. Passed in generation_kwargs under "num_return_sequences" and overwritten by generation_kwargs dict. + cot_delimiter (str): Delimiter to place between chain of thoughts and continuations. has_categories: (bool): If ``True``, we will search the dataset file for a category key, and partition the dataset into a separate dataloader for each category occurring in the data. + hf_loading_vars (Dict, default = None): A dictionary containing keyword arguments to be passed into `load_dataset` if dataset is being pulled from HF. + hf_parsing_map (Dict, default = None): A dictionary containing a mapping from HF columns to ICL dataset keys. The dictionary should be formatted {icl_key:[hf_key1, hf_key1]}. + Column contents will be concatenated with ' ' seperating them. 
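# --- Example (assumed dataset name and column names) of the HF-specific arguments described
# --- in the surrounding docstring. ---
# A 'hf://' URI is stripped of its prefix and passed to datasets.load_dataset along with
# hf_loading_vars; hf_parsing_map then concatenates the listed HF columns (joined by ' ')
# into each ICL dataset key.
dataset_uri = 'hf://some-org/some-eval-set'  # hypothetical HF dataset
hf_loading_vars = {'split': 'test', 'name': 'default'}
hf_parsing_map = {'context': ['passage', 'question'], 'answer': ['answer']}
# A row {'passage': 'P.', 'question': 'Q?', 'answer': 'A'} is parsed into
# {'context': 'P. Q?', 'answer': 'A'} before tokenization.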
If not included, will load the columns already present in the HF dataset. + generation_kwargs (Dict, default = None): A dictionary containing keyword arguments to be passed along to the model's generate function. Overwrites any previously specified generation + keyword args in this fucntion (see https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationConfig + for more details) + early_stopping (List, default = None): A list of strings that, when found in a model's output, will be treated as a stopping criteria at metric computation time. + Used in QA tasks with CoT + do_normalization (bool, default = True): Whether or not to normalize the outputs and labels in InContextLearningQAAccuracy. Only used in QA tasks. Returns: DataLoader: A dataloader used for performing in-context learning evaluation on the dataset provided. """ + if hf_loading_vars is None: + hf_loading_vars = {} + if hf_parsing_map is None: + hf_parsing_map = {} + if generation_kwargs is None: + generation_kwargs = {} + if early_stopping_criteria is None: + early_stopping_criteria = [] if has_categories: result_dls = {} - output_files = partition_dataset_by_category(dataset_uri, destination_path) + output_files = partition_dataset_by_category(dataset_uri, destination_path, hf_loading_vars, hf_parsing_map) categories = sorted(output_files.keys()) for category in categories: partition_uri = output_files[category] - result_dls[category] = build_icl_dataloader(icl_task_type, - partition_uri, - tokenizer, - batch_size, - max_seq_len, - pad_tok_id, - num_fewshot, - prompt_string, - example_delimiter, - continuation_delimiter, - partition_uri + '_tmp', - question_prelimiter, - cot_delimiter, - fewshot_random_seed, - pass_at_k, - generations_per_sample, - early_stopping_criteria=early_stopping_criteria, - do_normalization=do_normalization) + result_dls[category] = build_icl_dataloader( + icl_task_type=icl_task_type, + dataset_uri=partition_uri, + tokenizer=tokenizer, + batch_size=batch_size, + max_seq_len=max_seq_len, + pad_tok_id=pad_tok_id, + num_fewshot=num_fewshot, + prompt_string=prompt_string, + example_delimiter=example_delimiter, + continuation_delimiter=continuation_delimiter, + destination_path=partition_uri + '_tmp', + prelimiter=question_prelimiter, + cot_delimiter=cot_delimiter, + fewshot_random_seed=fewshot_random_seed, + pass_at_k=pass_at_k, + generations_per_sample=generations_per_sample, + hf_loading_vars=hf_loading_vars, + hf_parsing_map=hf_parsing_map, + generation_kwargs=generation_kwargs, + early_stopping_criteria=early_stopping_criteria, + do_normalization=do_normalization, + ) return result_dls else: - return build_icl_dataloader(icl_task_type, - dataset_uri, - tokenizer, - batch_size, - max_seq_len, - pad_tok_id, - num_fewshot, - prompt_string, - example_delimiter, - continuation_delimiter, - destination_path, - question_prelimiter, - cot_delimiter, - fewshot_random_seed, - pass_at_k, - generations_per_sample, - early_stopping_criteria=early_stopping_criteria, - do_normalization=do_normalization) + return build_icl_dataloader( + icl_task_type=icl_task_type, + dataset_uri=dataset_uri, + tokenizer=tokenizer, + batch_size=batch_size, + max_seq_len=max_seq_len, + pad_tok_id=pad_tok_id, + num_fewshot=num_fewshot, + prompt_string=prompt_string, + example_delimiter=example_delimiter, + hf_loading_vars=hf_loading_vars, + hf_parsing_map=hf_parsing_map, + continuation_delimiter=continuation_delimiter, + destination_path=destination_path, + prelimiter=question_prelimiter, + 
cot_delimiter=cot_delimiter, + fewshot_random_seed=fewshot_random_seed, + pass_at_k=pass_at_k, + generations_per_sample=generations_per_sample, + generation_kwargs=generation_kwargs, + early_stopping_criteria=early_stopping_criteria, + do_normalization=do_normalization, + ) diff --git a/composer/datasets/utils.py b/composer/datasets/utils.py index 431a860900..b627ef8596 100644 --- a/composer/datasets/utils.py +++ b/composer/datasets/utils.py @@ -179,7 +179,7 @@ class MultiTokenEOSCriteria(transformers.StoppingCriteria): def __init__( self, stop_sequence: str, - tokenizer: transformers.PreTrainedTokenizer, + tokenizer: transformers.PreTrainedTokenizerBase, batch_size: int, ) -> None: self.done_tracker = [False] * batch_size @@ -196,7 +196,7 @@ def __init__( self.stop_sequence_id_len = len(self.stop_sequence_ids) + 2 self.tokenizer = tokenizer - def __call__(self, input_ids, scores: Optional[torch.FloatTensor] = None, **kwargs) -> bool: + def __call__(self, input_ids: torch.Tensor, scores: Optional[torch.FloatTensor] = None, **kwargs) -> bool: # For efficiency, we compare the last n tokens where n is the number of tokens in the stop_sequence lookback_ids_batch = input_ids[:, :][:, -self.stop_sequence_id_len:] @@ -213,7 +213,7 @@ def __call__(self, input_ids, scores: Optional[torch.FloatTensor] = None, **kwar return False not in self.done_tracker def stop_sequences_criteria( - tokenizer: transformers.PreTrainedTokenizer, + tokenizer: transformers.PreTrainedTokenizerBase, stop_sequences: List[str], batch_size: int, ) -> transformers.StoppingCriteriaList: diff --git a/tests/datasets/test_in_context_learning_datasets.py b/tests/datasets/test_in_context_learning_datasets.py index ec7df306d6..9a98e2b174 100644 --- a/tests/datasets/test_in_context_learning_datasets.py +++ b/tests/datasets/test_in_context_learning_datasets.py @@ -9,15 +9,27 @@ import pytest import torch -import transformers from torch.utils.data import DataLoader -from transformers import AutoTokenizer from composer import Evaluator from composer.core import DataSpec -from composer.datasets.in_context_learning_evaluation import (InContextLearningCodeEvalDataset, - _get_fewshot_sample_idxs, _make_padded_input, - get_icl_task_dataloader) + +# isort: off +from composer.datasets.in_context_learning_evaluation import ( + InContextLearningCodeEvalDataset, + InContextLearningDataset, + InContextLearningMultipleChoiceTaskDataset, + InContextLearningQATaskDataset, + InContextLearningSchemaTaskDataset, + _get_continuation_span, + _get_fewshot_sample_idxs, + _make_padded_input, + _tokenizer_needs_prefix_space, + _trim_context, + get_icl_task_dataloader, + strip_data, +) +# isort: on from composer.datasets.utils import MultiTokenEOSCriteria from composer.loggers import InMemoryLogger from composer.metrics import (InContextLearningCodeEvalAccuracy, InContextLearningLMAccuracy, @@ -28,19 +40,122 @@ from tests.common import device, world_size +def test_strip_data(): + data_to_strip = {'strip_data': ' boo! \n', 'has_space': ' wa hoo!', 'end_space': 'yoohoo! 
'} + stripped_data = strip_data(data_to_strip) + for k, v in stripped_data.items(): + assert k in data_to_strip + assert not v[0].isspace() + assert not v[-1].isspace() + + +@pytest.mark.skip(reason="Currently don't have a tokenizer that satisfies this test") +def test_tokenizer_needs_prefix_space_when_space_not_needed(tiny_gpt2_tokenizer): + assert not _tokenizer_needs_prefix_space(tiny_gpt2_tokenizer) + + +def test_tokenizer_needs_prefix_space_when_space_needed(): + transformers = pytest.importorskip('transformers') + tokenizer = transformers.AutoTokenizer.from_pretrained('facebook/opt-125m', + use_fast=False) # type: ignore reportUnboundVariable + assert _tokenizer_needs_prefix_space(tokenizer) + + +def test_trim_context(): + context = [0] * 99 + [1] * 2037 + continuation = [2] * 10 + max_seq_len = 2048 + trimmed_context = _trim_context(context, continuation, max_seq_len=max_seq_len) + assert len(trimmed_context) == 2038 + assert trimmed_context[0] == 0 + assert trimmed_context[1] == 1 + + +def test_trim_context_no_continuation(): + context = [0] * 2048 + max_seq_len = 2048 + trimmed_context = _trim_context(context, [], max_seq_len=max_seq_len) + assert len(trimmed_context) == 2048 + context = [0] * 3000 + [1] + max_seq_len = 2048 + trimmed_context = _trim_context(context, [], max_seq_len=max_seq_len) + assert len(trimmed_context) == 2048 + assert trimmed_context[-1] == 1 + + +def test_get_continuation_span(): + context = [0] * 200 + continuation = [1] * 3 + cont_span = _get_continuation_span(context, continuation) + assert torch.all(torch.eq(cont_span, torch.tensor([200, 201, 202]))) + continuation = [1] + cont_span = _get_continuation_span(context, continuation) + assert torch.all(torch.eq(cont_span, torch.tensor([200]))) + + +@pytest.mark.parametrize('padding_side', ['left', 'right', 'middle']) +def test_make_padding(tiny_gpt2_tokenizer, padding_side): + context = tiny_gpt2_tokenizer(' cat' * 2000)['input_ids'] + padding_id = tiny_gpt2_tokenizer.eos_token_id + + error_context = contextlib.nullcontext() if padding_side in {'left', 'right'} else pytest.raises(ValueError) + + with error_context: + input_ids = _make_padded_input(context, [], 2048, padding_id, padding_side=padding_side) + + if padding_side == 'left': + assert input_ids[0] == tiny_gpt2_tokenizer.eos_token_id + assert input_ids[48:].tolist() == context + elif padding_side == 'right': + assert input_ids[-1] == tiny_gpt2_tokenizer.eos_token_id + assert input_ids[:-48].tolist() == context + + +def test_batch_padding_logic_no_padding(tiny_gpt2_tokenizer): + continuation = tiny_gpt2_tokenizer(' dog' * 2000)['input_ids'] + context = tiny_gpt2_tokenizer(' cat' * 2000)['input_ids'] + max_seq_len = 2048 + trimmed_context = _trim_context(context, continuation, max_seq_len) + continuation_spans = _get_continuation_span(trimmed_context, continuation) + padded_input = _make_padded_input(trimmed_context, + continuation, + max_seq_len, + tiny_gpt2_tokenizer.pad_token_id, + padding_side='right') + assert continuation_spans[0] == 48 and continuation_spans[-1] == 2047 + assert len(padded_input) == 2048 + assert tiny_gpt2_tokenizer.pad_token_id not in padded_input + + +def test_batch_padding_logic_with_padding(tiny_gpt2_tokenizer): + continuation = tiny_gpt2_tokenizer(' dog' * 200)['input_ids'] + context = tiny_gpt2_tokenizer(' cat' * 200)['input_ids'] + max_seq_len = 2048 + trimmed_context = _trim_context(context, continuation, max_seq_len) + continuation_spans = _get_continuation_span(trimmed_context, continuation) + padded_input = 
_make_padded_input(trimmed_context, + continuation, + max_seq_len, + tiny_gpt2_tokenizer.pad_token_id, + padding_side='right') + assert continuation_spans[0] == 200 and continuation_spans[-1] == 399 + assert len(padded_input) == 2048 + assert padded_input[-1] == tiny_gpt2_tokenizer.pad_token_id + + def test_fewshot_sample_idxs(): rng = random.Random(1234) - fewshot_idxs = _get_fewshot_sample_idxs(dataset_size=5, num_fewshot=4, sample_idx=4, rng=rng) + fewshot_idxs = _get_fewshot_sample_idxs(dataset_size=5, num_fewshot=4, example_idx=4, rng=rng) assert fewshot_idxs == {0, 1, 2, 3} - fewshot_idxs = _get_fewshot_sample_idxs(dataset_size=5, num_fewshot=5, sample_idx=4, rng=rng) + fewshot_idxs = _get_fewshot_sample_idxs(dataset_size=5, num_fewshot=5, example_idx=4, rng=rng) assert fewshot_idxs == {0, 1, 2, 3} - fewshot_idxs = _get_fewshot_sample_idxs(dataset_size=5, num_fewshot=500, sample_idx=4, rng=rng) + fewshot_idxs = _get_fewshot_sample_idxs(dataset_size=5, num_fewshot=500, example_idx=4, rng=rng) assert fewshot_idxs == {0, 1, 2, 3} - fewshot_idxs = _get_fewshot_sample_idxs(dataset_size=10, num_fewshot=7, sample_idx=4, rng=rng) + fewshot_idxs = _get_fewshot_sample_idxs(dataset_size=10, num_fewshot=7, example_idx=4, rng=rng) assert len(fewshot_idxs) == 7 and 4 not in fewshot_idxs @@ -67,6 +182,37 @@ def test_fewshot_sample_idxs_randomness(): assert rng_1_sample_2 != rng_3_sample_2 +@pytest.mark.filterwarnings( + r'ignore:The repository for mosaicml/test_dataset contains custom code which must*:FutureWarning') +def test_update_generation_kwargs(tiny_gpt2_tokenizer, tmp_path): + tokenizer = tiny_gpt2_tokenizer + seqlen = 2048 + num_fewshot = 0 + prompt_string = '' + hf_loading_vars = { + 'split': 'test', + 'name': 'invoker', + } + hf_parsing_map = {'context': ['quas', 'wex', 'exort'], 'answer': ['spell']} + gen_kwargs = {'test_arg1': 1, 'test_arg2': 2} + + dl = InContextLearningDataset(dataset_uri='hf://mosaicml/test_dataset', + tokenizer=tokenizer, + max_seq_len=seqlen, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=num_fewshot, + fewshot_random_seed=1, + prompt_string=prompt_string, + example_delimiter='\n', + prelimiter='Orbs: ', + continuation_delimiter='\nSpell:', + destination_path=str(tmp_path / 'test_dataset_lm_juggernaut.jsonl'), + hf_loading_vars=hf_loading_vars, + hf_parsing_map=hf_parsing_map, + generation_kwargs=gen_kwargs) + assert dl.base_batch['generation_kwargs'] == {'test_arg1': 1, 'test_arg2': 2} + + def test_stop_sequences_criteria(tiny_gpt2_tokenizer): pytest.importorskip('transformers') eos_criteria = MultiTokenEOSCriteria('\n\n', tiny_gpt2_tokenizer, 2) @@ -83,30 +229,554 @@ def test_stop_sequences_criteria(tiny_gpt2_tokenizer): assert eos_criteria(input_ids, None) -def test_batch_padding_logic(tiny_gpt2_tokenizer): - continuation = tiny_gpt2_tokenizer(' dog' * 2000)['input_ids'] - context = tiny_gpt2_tokenizer(' cat' * 2000)['input_ids'] - _, continuation_spans = _make_padded_input(context, continuation, 2048, tiny_gpt2_tokenizer.eos_token_id) - # the context (of len 2000) gets clipped to len 48 so that the whole continuation can fit - assert continuation_spans[0] == 48 and continuation_spans[-1] == 2047 +@pytest.mark.filterwarnings( + r'ignore:The repository for mosaicml/test_dataset contains custom code which must*:FutureWarning') +def test_update_generation_kwargs_no_kwargs(tiny_gpt2_tokenizer, tmp_path): + tokenizer = tiny_gpt2_tokenizer + seqlen = 2048 + num_fewshot = 0 + prompt_string = '' + hf_loading_vars = { + 'split': 'test', + 'name': 'invoker', + } + 
hf_parsing_map = {'context': ['quas', 'wex', 'exort'], 'answer': ['spell']} + + dl = InContextLearningDataset(dataset_uri='hf://mosaicml/test_dataset', + tokenizer=tokenizer, + max_seq_len=seqlen, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=num_fewshot, + fewshot_random_seed=1, + prompt_string=prompt_string, + example_delimiter='\n', + prelimiter='Orbs: ', + continuation_delimiter='\nSpell:', + destination_path=str(tmp_path / 'test_dataset_lm_juggernaut.jsonl'), + hf_loading_vars=hf_loading_vars, + hf_parsing_map=hf_parsing_map) + assert not dl.base_batch['generation_kwargs'] -@pytest.mark.parametrize('padding_side', ['left', 'right', 'middle']) -def test_make_padding(tiny_gpt2_tokenizer, padding_side): - context = tiny_gpt2_tokenizer(' cat' * 2000)['input_ids'] - padding_id = tiny_gpt2_tokenizer.eos_token_id +@pytest.mark.filterwarnings( + r'ignore:The repository for mosaicml/test_dataset contains custom code which must*:FutureWarning') +def test_construct_context(tiny_gpt2_tokenizer, tmp_path): + tokenizer = tiny_gpt2_tokenizer + seqlen = 2048 + num_fewshot = 0 + prompt_string = '' + hf_loading_vars = { + 'split': 'test', + 'name': 'invoker', + } + hf_parsing_map = {'context': ['quas', 'wex', 'exort'], 'answer': ['spell']} - error_context = contextlib.nullcontext() if padding_side in {'left', 'right'} else pytest.raises(ValueError) + dl = InContextLearningDataset(dataset_uri='hf://mosaicml/test_dataset', + tokenizer=tokenizer, + max_seq_len=seqlen, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=num_fewshot, + fewshot_random_seed=1, + prompt_string=prompt_string, + example_delimiter='\n', + prelimiter='Orbs: ', + continuation_delimiter='\nSpell: ', + destination_path=str(tmp_path / 'test_dataset_lm_juggernaut.jsonl'), + hf_loading_vars=hf_loading_vars, + hf_parsing_map=hf_parsing_map) + constructed_context = dl.construct_context({'context': 'quas quas exort', 'answer': 'ice wall'}) + assert constructed_context == 'Orbs: quas quas exort\nSpell: ' + constructed_context = dl.construct_context({'context': 'quas quas exort', 'answer': 'ice wall'}, add_answer=True) + assert constructed_context == 'Orbs: quas quas exort\nSpell: ice wall' + constructed_context = dl.construct_context({ + 'context': 'quas quas exort', + 'answer': 'ice wall' + }, + preceding_text='The harsh White Waste beckons!', + add_answer=True) + assert constructed_context == '\nOrbs: quas quas exort\nSpell: ice wall' + + +@pytest.mark.filterwarnings( + r'ignore:The repository for mosaicml/test_dataset contains custom code which must*:FutureWarning') +def test_get_answer_from_example(tiny_gpt2_tokenizer, tmp_path): + tokenizer = tiny_gpt2_tokenizer + seqlen = 2048 + num_fewshot = 0 + prompt_string = '' + hf_loading_vars = { + 'split': 'test', + 'name': 'invoker', + } + hf_parsing_map = {'context': ['quas', 'wex', 'exort'], 'answer': ['spell']} - with error_context: - input_ids, _ = _make_padded_input(context, [], 2048, padding_id, padding_side=padding_side) + dl = InContextLearningDataset(dataset_uri='hf://mosaicml/test_dataset', + tokenizer=tokenizer, + max_seq_len=seqlen, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=num_fewshot, + fewshot_random_seed=1, + prompt_string=prompt_string, + example_delimiter='\n', + prelimiter='Orbs: ', + continuation_delimiter='\nSpell:', + destination_path=str(tmp_path / 'test_dataset_lm_juggernaut.jsonl'), + hf_loading_vars=hf_loading_vars, + hf_parsing_map=hf_parsing_map) + answer = dl.get_answer_from_example({'context': 'wex exort exort', 'answer': 'alacrity'}) + assert answer 
== ' alacrity' + + +@pytest.mark.filterwarnings( + r'ignore:The repository for mosaicml/test_dataset contains custom code which must*:FutureWarning') +def test_fix_eos_on_preamble(tmp_path): + transformers = pytest.importorskip('transformers') + tokenizer = transformers.AutoTokenizer.from_pretrained('facebook/opt-125m', + use_fast=False) # type: ignore reportUnboundVariable + seqlen = 2048 + num_fewshot = 0 + prompt_string = '' + hf_loading_vars = { + 'split': 'test', + 'name': 'invoker', + } + hf_parsing_map = {'context': ['quas', 'wex', 'exort'], 'answer': ['spell']} - if padding_side == 'left': - assert input_ids[0] == tiny_gpt2_tokenizer.eos_token_id - assert input_ids[48:].tolist() == context - elif padding_side == 'right': - assert input_ids[-1] == tiny_gpt2_tokenizer.eos_token_id - assert input_ids[:-48].tolist() == context + dl = InContextLearningDataset(dataset_uri='hf://mosaicml/test_dataset', + tokenizer=tokenizer, + max_seq_len=seqlen, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=num_fewshot, + fewshot_random_seed=1, + prompt_string=prompt_string, + example_delimiter='\n', + prelimiter='Orbs: ', + continuation_delimiter='\nSpell:', + destination_path=str(tmp_path / 'test_dataset_lm_juggernaut.jsonl'), + hf_loading_vars=hf_loading_vars, + hf_parsing_map=hf_parsing_map) + preamble = 'blah blah blah.' + tokenized_preamble = tokenizer.encode(preamble) + tokenized_preamble += [tokenizer.eos_token_id] + fixed_preamble = dl._fix_eos_on_preamble(tokenized_preamble) + assert tokenized_preamble[:-1] == fixed_preamble + assert fixed_preamble[-1] != tokenizer.eos_token_id + + +@pytest.mark.filterwarnings( + r'ignore:The repository for mosaicml/test_dataset contains custom code which must*:FutureWarning') +def test_tokenize_example_with_tokenize_labels(tiny_gpt2_tokenizer, tmp_path): + tokenizer = tiny_gpt2_tokenizer + seqlen = 2048 + num_fewshot = 0 + prompt_string = '' + hf_loading_vars = { + 'split': 'test', + 'name': 'invoker', + } + hf_parsing_map = {'context': ['quas', 'wex', 'exort'], 'answer': ['spell']} + + dl = InContextLearningDataset(dataset_uri='hf://mosaicml/test_dataset', + tokenizer=tokenizer, + max_seq_len=seqlen, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=num_fewshot, + fewshot_random_seed=1, + prompt_string=prompt_string, + example_delimiter='\n', + prelimiter='Orbs: ', + continuation_delimiter='\nSpell: ', + destination_path=str(tmp_path / 'test_dataset_lm_juggernaut.jsonl'), + hf_loading_vars=hf_loading_vars, + hf_parsing_map=hf_parsing_map, + tokenize_labels=True) + tokenized_example = dl.tokenize_example('What spell does this invoke? 
', 'exort exort wex\nSpell: ', + {'answer': ' Meatball'}) + tokenized_input = [2061, 4822, 857, 428, 26342, 30, 220, 1069, 419, 409, 419, 356, 87, 198, 31221, 25, 19145, 1894] + assert tokenized_example['context'][:len(tokenized_input)].tolist() == tokenized_input + assert tokenized_example['context'][-1] == tokenizer.eos_token_id + assert type(tokenized_example['answer'][0]) == int + assert len(tokenized_example['context']) == seqlen + assert 'continuation_indices' in tokenized_example + + +@pytest.mark.filterwarnings( + r'ignore:The repository for mosaicml/test_dataset contains custom code which must*:FutureWarning') +def test_tokenize_example_with_no_tokenize_labels(tiny_gpt2_tokenizer, tmp_path): + tokenizer = tiny_gpt2_tokenizer + seqlen = 2048 + num_fewshot = 0 + prompt_string = '' + hf_loading_vars = { + 'split': 'test', + 'name': 'invoker', + } + hf_parsing_map = {'context': ['quas', 'wex', 'exort'], 'answer': ['spell']} + + dl = InContextLearningDataset(dataset_uri='hf://mosaicml/test_dataset', + tokenizer=tokenizer, + max_seq_len=seqlen, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=num_fewshot, + fewshot_random_seed=1, + prompt_string=prompt_string, + example_delimiter='\n', + prelimiter='Orbs: ', + continuation_delimiter='\nSpell: ', + destination_path=str(tmp_path / 'test_dataset_lm_juggernaut.jsonl'), + hf_loading_vars=hf_loading_vars, + hf_parsing_map=hf_parsing_map, + tokenize_labels=False) + tokenized_example = dl.tokenize_example('What spell does this invoke? ', 'exort exort wex\nSpell: ', + {'answer': ' Meatball'}) + tokenized_input = [2061, 4822, 857, 428, 26342, 30, 220, 1069, 419, 409, 419, 356, 87, 198, 31221, 25] + assert tokenized_example['context'][:len(tokenized_input)].tolist() == tokenized_input + assert tokenized_example['context'][-1] == tokenizer.eos_token_id + assert len(tokenized_example['context']) == seqlen + assert type(tokenized_example['answer']) == str + + +def test_qa_set_cot_no_cot(tmp_path): + pytest.importorskip('datasets') + local_data = os.path.join(os.path.dirname(__file__), 'local_data') + dataset_uri = f'{local_data}/triviaqa_small.jsonl' + transformers = pytest.importorskip('transformers') + tokenizer = transformers.AutoTokenizer.from_pretrained('facebook/opt-125m') # type: ignore reportUnboundVariable + + tmp_path_to_broadcast = str(os.path.abspath(tmp_path)) + gathered_paths = dist.all_gather_object(tmp_path_to_broadcast) + dl = InContextLearningQATaskDataset( + dataset_uri=dataset_uri, + tokenizer=tokenizer, + max_seq_len=1024, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=0, + fewshot_random_seed=1234, + prompt_string='', + example_delimiter='\n', + continuation_delimiter=': ', + destination_path=str(Path(gathered_paths[0]) / 'icl.jsonl'), + ) + assert not dl.has_cot + + +def test_qa_set_cot_has_cot(tmp_path): + pytest.importorskip('datasets') + local_data = os.path.join(os.path.dirname(__file__), 'local_data') + dataset_uri = f'{local_data}/gsm8k_small.jsonl' + transformers = pytest.importorskip('transformers') + tokenizer = transformers.AutoTokenizer.from_pretrained('facebook/opt-125m') # type: ignore reportUnboundVariable + + tmp_path_to_broadcast = str(os.path.abspath(tmp_path)) + gathered_paths = dist.all_gather_object(tmp_path_to_broadcast) + dl = InContextLearningQATaskDataset( + dataset_uri=dataset_uri, + tokenizer=tokenizer, + max_seq_len=1024, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=0, + fewshot_random_seed=1234, + prompt_string='', + example_delimiter='\n', + continuation_delimiter=': ', + 
destination_path=str(Path(gathered_paths[0]) / 'icl.jsonl'), + ) + assert dl.has_cot + + +def test_qa_get_max_answer_length(tiny_gpt2_tokenizer, tmp_path): + local_data = os.path.join(os.path.dirname(__file__), 'local_data') + dataset_uri = f'{local_data}/triviaqa_small.jsonl' + tokenizer = tiny_gpt2_tokenizer + + tmp_path_to_broadcast = str(os.path.abspath(tmp_path)) + gathered_paths = dist.all_gather_object(tmp_path_to_broadcast) + dl = InContextLearningQATaskDataset( + dataset_uri=dataset_uri, + tokenizer=tokenizer, + max_seq_len=1024, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=0, + fewshot_random_seed=1234, + prompt_string='', + example_delimiter='', + continuation_delimiter='', + cot_delimiter='', + destination_path=str(Path(gathered_paths[0]) / 'icl.jsonl'), + ) + # empirical number from the small test dataset + assert dl.max_answer_length == 7 + + +def test_qa_get_answer_from_example_with_no_cot(tmp_path, tiny_gpt2_tokenizer): + pytest.importorskip('datasets') + local_data = os.path.join(os.path.dirname(__file__), 'local_data') + dataset_uri = f'{local_data}/triviaqa_small.jsonl' + + tmp_path_to_broadcast = str(os.path.abspath(tmp_path)) + gathered_paths = dist.all_gather_object(tmp_path_to_broadcast) + dl = InContextLearningQATaskDataset( + dataset_uri=dataset_uri, + tokenizer=tiny_gpt2_tokenizer, + max_seq_len=1024, + pad_tok_id=tiny_gpt2_tokenizer.eos_token_id, + num_fewshot=0, + fewshot_random_seed=1234, + prompt_string='', + example_delimiter='\n', + continuation_delimiter=': ', + cot_delimiter=' ### ', + destination_path=str(Path(gathered_paths[0]) / 'icl.jsonl'), + ) + answer = dl.get_answer_from_example({ + 'context': 'empty', + 'answer': 'this is the correct answer', + 'chain_of_thought': "Let's think step by step. " + }) + assert answer == 'this is the correct answer' + + +def test_qa_get_answer_from_example_with_cot(tmp_path, tiny_gpt2_tokenizer): + pytest.importorskip('datasets') + local_data = os.path.join(os.path.dirname(__file__), 'local_data') + dataset_uri = f'{local_data}/triviaqa_small.jsonl' + + tmp_path_to_broadcast = str(os.path.abspath(tmp_path)) + gathered_paths = dist.all_gather_object(tmp_path_to_broadcast) + dl = InContextLearningQATaskDataset( + dataset_uri=dataset_uri, + tokenizer=tiny_gpt2_tokenizer, + max_seq_len=1024, + pad_tok_id=tiny_gpt2_tokenizer.eos_token_id, + num_fewshot=0, + fewshot_random_seed=1234, + prompt_string='', + example_delimiter='\n', + continuation_delimiter=': ', + cot_delimiter=' ### ', + destination_path=str(Path(gathered_paths[0]) / 'icl.jsonl'), + ) + dl.has_cot = True + answer = dl.get_answer_from_example({ + 'context': 'empty', + 'answer': 'this is the correct answer', + 'chain_of_thought': "Let's think step by step. " + }) + assert answer == "Let's think step by step. 
### this is the correct answer" + + +def test_qa_tokenize_example(tiny_gpt2_tokenizer, tmp_path): + pytest.importorskip('datasets') + local_data = os.path.join(os.path.dirname(__file__), 'local_data') + dataset_uri = f'{local_data}/triviaqa_small.jsonl' + + tmp_path_to_broadcast = str(os.path.abspath(tmp_path)) + gathered_paths = dist.all_gather_object(tmp_path_to_broadcast) + dl = InContextLearningQATaskDataset( + dataset_uri=dataset_uri, + tokenizer=tiny_gpt2_tokenizer, + max_seq_len=1024, + pad_tok_id=tiny_gpt2_tokenizer.eos_token_id, + num_fewshot=0, + fewshot_random_seed=1234, + prompt_string='', + example_delimiter='\n', + continuation_delimiter=': ', + cot_delimiter=' ### ', + destination_path=str(Path(gathered_paths[0]) / 'icl.jsonl'), + ) + dl.has_cot = True + tokenized_example = dl.tokenize_example( + 'starting prompt', 'a context', { + 'context': 'empty', + 'answer': 'this is the correct answer', + 'aliases': ['this is the right answer', 'this is the best answer'], + 'chain_of_thought': "Let's think step by step. " + }) + assert 'aliases' in tokenized_example + assert tokenized_example['aliases'] == ['this is the right answer', 'this is the best answer'] + + +def test_code_adjust_padding(tiny_gpt2_tokenizer, tmp_path): + local_data = os.path.join(os.path.dirname(__file__), 'local_data') + dataset_uri = f'{local_data}/human_eval_small.jsonl' + tokenizer = tiny_gpt2_tokenizer + seqlen = 2048 + num_fewshot = 0 + prompt_string = '' + gen_kwargs = {'temperature': .9, 'top_p': .95, 'num_beams': 9000} + + dl = InContextLearningCodeEvalDataset( + dataset_uri=dataset_uri, + tokenizer=tokenizer, + max_seq_len=seqlen, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=num_fewshot, + fewshot_random_seed=1, + prompt_string=prompt_string, + example_delimiter='\n', + prelimiter='Code start:', + continuation_delimiter='\nPlease code:', + destination_path=str(tmp_path / 'test_human_eval_small.jsonl'), + generation_kwargs=gen_kwargs, + generations_per_sample=10, + ) + + assert all(len(data['prompt']) == 148 for data in dl.dataset) # pyright: ignore [reportGeneralTypeIssues] + + +def test_code_update_gen_kwargs(tiny_gpt2_tokenizer, tmp_path): + local_data = os.path.join(os.path.dirname(__file__), 'local_data') + dataset_uri = f'{local_data}/human_eval_small.jsonl' + tokenizer = tiny_gpt2_tokenizer + seqlen = 2048 + num_fewshot = 0 + prompt_string = '' + gen_kwargs = {'temperature': .9, 'top_p': .95, 'num_beams': 9000} + + dl = InContextLearningCodeEvalDataset( + dataset_uri=dataset_uri, + tokenizer=tokenizer, + max_seq_len=seqlen, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=num_fewshot, + fewshot_random_seed=1, + prompt_string=prompt_string, + example_delimiter='\n', + prelimiter='Code start:', + continuation_delimiter='\nPlease code:', + destination_path=str(tmp_path / 'test_human_eval_small.jsonl'), + generation_kwargs=gen_kwargs, + generations_per_sample=10, + ) + assert dl.base_batch['generation_kwargs']['num_beams'] == 9000 + assert dl.base_batch['generation_kwargs']['top_p'] == .95 + assert dl.base_batch['generation_kwargs']['temperature'] == .9 + assert dl.base_batch['generation_kwargs']['do_sample'] == True + + +def test_mc_tokenize_example(tiny_gpt2_tokenizer, tmp_path): + local_data = os.path.join(os.path.dirname(__file__), 'local_data') + dataset_uri = f'{local_data}/mmlu_small.jsonl' + tokenizer = tiny_gpt2_tokenizer + seqlen = 2048 + num_fewshot = 0 + prompt_string = '' + seqlen = 2048 + dl = InContextLearningMultipleChoiceTaskDataset( + dataset_uri=dataset_uri, + 
tokenizer=tokenizer, + max_seq_len=seqlen, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=num_fewshot, + fewshot_random_seed=1, + prompt_string=prompt_string, + example_delimiter='\n', + continuation_delimiter=' ### ', + destination_path=str(tmp_path / 'test_human_eval_small.jsonl'), + ) + example = { + 'context': "Who's the best eval researcher?\n A. Jeremy\n B. Tessa\n C. Max\n D. Other\nAnswer: ", + 'choices': ['A', 'B', 'C', 'D'], + 'gold': 2 + } + tokenized_example = dl.tokenize_example(prompt_and_fewshot='Answer the following: ', + ctxt=example['context'], + example=example) + unpadded_queries = [context[context != tokenizer.eos_token_id] for context in tokenized_example['query']] + untokenized_inputs = [tokenizer.decode(unpadded_input) for unpadded_input in unpadded_queries] + correct_output = [ + "Answer the following: Who's the best eval researcher?\n A. Jeremy\n B. Tessa\n C. Max\n D. Other\nAnswer: A", + "Answer the following: Who's the best eval researcher?\n A. Jeremy\n B. Tessa\n C. Max\n D. Other\nAnswer: B", + "Answer the following: Who's the best eval researcher?\n A. Jeremy\n B. Tessa\n C. Max\n D. Other\nAnswer: C", + "Answer the following: Who's the best eval researcher?\n A. Jeremy\n B. Tessa\n C. Max\n D. Other\nAnswer: D" + ] + assert untokenized_inputs == correct_output + + +def test_schema_construct_context(tiny_gpt2_tokenizer, tmp_path): + local_data = os.path.join(os.path.dirname(__file__), 'local_data') + dataset_uri = f'{local_data}/winograd_small.jsonl' + tokenizer = tiny_gpt2_tokenizer + seqlen = 2048 + num_fewshot = 0 + seqlen = 2048 + dl = InContextLearningSchemaTaskDataset( + dataset_uri=dataset_uri, + tokenizer=tokenizer, + max_seq_len=seqlen, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=num_fewshot, + fewshot_random_seed=1, + prompt_string='', + example_delimiter='\n', + continuation_delimiter=' ### ', + destination_path=str(tmp_path / 'test_human_eval_small.jsonl'), + ) + example = {'context_options': ['cont one', 'cont two'], 'gold': 0, 'continuation': 'this is a continuation'} + constructed_context = dl.construct_context(example) + assert constructed_context == 'cont one ### this is a continuation' + constructed_context = dl.construct_context(example, preceding_text='text') + assert constructed_context == '\ncont one ### this is a continuation' + + +def test_schema_construct_multiple_contexts(tiny_gpt2_tokenizer, tmp_path): + local_data = os.path.join(os.path.dirname(__file__), 'local_data') + dataset_uri = f'{local_data}/winograd_small.jsonl' + tokenizer = tiny_gpt2_tokenizer + seqlen = 2048 + num_fewshot = 0 + prompt_string = '' + seqlen = 2048 + dl = InContextLearningSchemaTaskDataset( + dataset_uri=dataset_uri, + tokenizer=tokenizer, + max_seq_len=seqlen, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=num_fewshot, + fewshot_random_seed=1, + prompt_string=prompt_string, + example_delimiter='\n', + continuation_delimiter=' ### ', + destination_path=str(tmp_path / 'test_human_eval_small.jsonl'), + ) + example = {'context_options': ['cont one', 'cont two'], 'gold': 0, 'continuation': 'this is a continuation'} + constructed_contexts = dl._construct_multiple_contexts(example) + assert constructed_contexts == ['cont one', 'cont two'] + constructed_contexts = dl._construct_multiple_contexts(example, preceding_text='some text') + assert constructed_contexts == ['\ncont one ###', '\ncont two ###'] + + +def test_schema_tokenize_example(tiny_gpt2_tokenizer, tmp_path): + local_data = os.path.join(os.path.dirname(__file__), 'local_data') + 
dataset_uri = f'{local_data}/winograd_small.jsonl' + tokenizer = tiny_gpt2_tokenizer + seqlen = 2048 + num_fewshot = 0 + prompt_string = '' + seqlen = 2048 + dl = InContextLearningSchemaTaskDataset( + dataset_uri=dataset_uri, + tokenizer=tokenizer, + max_seq_len=seqlen, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=num_fewshot, + fewshot_random_seed=1, + prompt_string=prompt_string, + example_delimiter='\n', + continuation_delimiter=' ### ', + destination_path=str(tmp_path / 'test_human_eval_small.jsonl'), + ) + example = {'context_options': ['context one', 'context two'], 'gold': 0, 'continuation': 'this is a continuation'} + tokenized_example = dl.tokenize_example(prompt_and_fewshot='prompt ', + context_options=example['context_options'], + example=example) + assert all(tiny_gpt2_tokenizer.decode(cont) == ' this is a continuation' for cont in tokenized_example['answer']) + unpadded_inputs = [context[context != tokenizer.eos_token_id] for context in tokenized_example['context_options']] + untokenized_inputs = [tokenizer.decode(unpadded_input) for unpadded_input in unpadded_inputs] + assert untokenized_inputs == [ + 'prompt context one this is a continuation', 'prompt context two this is a continuation' + ] @pytest.mark.parametrize('dataset_uri', ['mmlu_small.jsonl']) @@ -120,9 +790,9 @@ def test_mc_task_dataloader_subcategories(dataset_uri, tiny_gpt2_tokenizer, tmp_ batch_size = 8 seqlen = 64 dls = get_icl_task_dataloader('multiple_choice', - dataset_uri, - tokenizer, - batch_size, + dataset_uri=dataset_uri, + tokenizer=tokenizer, + batch_size=batch_size, max_seq_len=seqlen, pad_tok_id=tokenizer.eos_token_id, num_fewshot=2, @@ -164,9 +834,9 @@ def test_lm_task_dataloader_extra_space(dataset_uri, tiny_gpt2_tokenizer, tmp_pa batch_size = 2 seqlen = 64 dl = get_icl_task_dataloader('language_modeling', - dataset_uri, - tokenizer, - batch_size, + dataset_uri=dataset_uri, + tokenizer=tokenizer, + batch_size=batch_size, max_seq_len=seqlen, pad_tok_id=tokenizer.eos_token_id, num_fewshot=10, @@ -205,9 +875,9 @@ def test_lm_task_dataloader(dataset_uri, tiny_gpt2_tokenizer, tmp_path): batch_size = 2 seqlen = 64 dl = get_icl_task_dataloader('language_modeling', - dataset_uri, - tokenizer, - batch_size, + dataset_uri=dataset_uri, + tokenizer=tokenizer, + batch_size=batch_size, max_seq_len=seqlen, pad_tok_id=tokenizer.eos_token_id, num_fewshot=0, @@ -243,9 +913,9 @@ def test_schema_task_dataloader(dataset_uri, tiny_gpt2_tokenizer, tmp_path): batch_size = 2 seqlen = 64 dl = get_icl_task_dataloader('schema', - dataset_uri, - tokenizer, - batch_size, + dataset_uri=dataset_uri, + tokenizer=tokenizer, + batch_size=batch_size, max_seq_len=seqlen, pad_tok_id=tokenizer.eos_token_id, num_fewshot=1, @@ -280,17 +950,19 @@ def test_schema_task_dataloader(dataset_uri, tiny_gpt2_tokenizer, tmp_path): @pytest.mark.parametrize('dataset_uri', ['winograd_small.jsonl']) def test_schema_task_dataloader_sentpiece_tokenizer(dataset_uri, tmp_path): pytest.importorskip('datasets') + transformers = pytest.importorskip('transformers') local_data = os.path.join(os.path.dirname(__file__), 'local_data') - - tokenizer = AutoTokenizer.from_pretrained('huggyllama/llama-7b', use_fast=False) + tokenizer = transformers.AutoTokenizer.from_pretrained( + 'huggyllama/llama-7b', # type: ignore reportUnboundVariable + use_fast=False) dataset_uri = f'{local_data}/{dataset_uri}' batch_size = 2 seqlen = 64 dl = get_icl_task_dataloader('schema', - dataset_uri, - tokenizer, - batch_size, + dataset_uri=dataset_uri, + tokenizer=tokenizer, + 
batch_size=batch_size, max_seq_len=seqlen, pad_tok_id=tokenizer.eos_token_id, num_fewshot=1, @@ -335,9 +1007,9 @@ def test_lm_task_dataloader_opt_tokenizer(tiny_opt_tokenizer, dataset_uri, num_f batch_size = 2 seqlen = 512 dl = get_icl_task_dataloader('language_modeling', - dataset_uri, - tokenizer, - batch_size, + dataset_uri=dataset_uri, + tokenizer=tokenizer, + batch_size=batch_size, max_seq_len=seqlen, pad_tok_id=tokenizer.eos_token_id, num_fewshot=num_fewshot, @@ -377,9 +1049,9 @@ def test_mc_task_dataloader_opt_tokenizer(tiny_opt_tokenizer, dataset_uri, num_f batch_size = 4 seqlen = 64 dl = get_icl_task_dataloader('multiple_choice', - dataset_uri, - tokenizer, - batch_size, + dataset_uri=dataset_uri, + tokenizer=tokenizer, + batch_size=batch_size, max_seq_len=seqlen, pad_tok_id=tokenizer.eos_token_id, num_fewshot=num_fewshot, @@ -427,9 +1099,9 @@ def test_mc_split_batch(tiny_opt_tokenizer, dataset_uri, num_fewshot, tmp_path): batch_size = 4 seqlen = 512 dl = get_icl_task_dataloader('multiple_choice', - dataset_uri, - tokenizer, - batch_size, + dataset_uri=dataset_uri, + tokenizer=tokenizer, + batch_size=batch_size, max_seq_len=seqlen, pad_tok_id=tokenizer.eos_token_id, num_fewshot=num_fewshot, @@ -483,13 +1155,13 @@ def test_qa_split_batch(tiny_opt_tokenizer, dataset_uri, tmp_path): tokenizer = tiny_opt_tokenizer tmp_path_to_broadcast = str(os.path.abspath(tmp_path)) - gathered_paths = dist.all_gather_object(tmp_path_to_broadcast) + gathered_paths = dist.all_gather_object(tmp_path_to_broadcast) # for dist dl = get_icl_task_dataloader( - 'question_answering', - dataset_uri, - tokenizer, - 8, - max_seq_len=64, + icl_task_type='question_answering', + dataset_uri=dataset_uri, + tokenizer=tokenizer, + batch_size=8, + max_seq_len=1024, pad_tok_id=tokenizer.eos_token_id, num_fewshot=0, prompt_string='', @@ -570,9 +1242,9 @@ def test_qa_task_dataloader(dataset_uri, tiny_gpt2_tokenizer, tmp_path, num_fews # empirical number from the small test dataset maximum_answer_length = 7 dl = get_icl_task_dataloader('question_answering', - dataset_uri, - tokenizer, - batch_size, + dataset_uri=dataset_uri, + tokenizer=tokenizer, + batch_size=batch_size, max_seq_len=seqlen, pad_tok_id=tokenizer.eos_token_id, num_fewshot=num_fewshot, @@ -622,9 +1294,9 @@ def test_qa_task_with_cot_dataloader(dataset_uri, tiny_gpt2_tokenizer, tmp_path, # empirical number from the small test dataset maximum_answer_length = 132 dl = get_icl_task_dataloader('question_answering', - dataset_uri, - tokenizer, - batch_size, + dataset_uri=dataset_uri, + tokenizer=tokenizer, + batch_size=batch_size, max_seq_len=seqlen, pad_tok_id=tokenizer.eos_token_id, num_fewshot=num_fewshot, @@ -675,9 +1347,9 @@ def test_mc_task_dataloader(dataset_uri, tiny_gpt2_tokenizer, tmp_path): batch_size = 2 seqlen = 64 dl = get_icl_task_dataloader('multiple_choice', - dataset_uri, - tokenizer, - batch_size, + dataset_uri=dataset_uri, + tokenizer=tokenizer, + batch_size=batch_size, max_seq_len=seqlen, pad_tok_id=tokenizer.eos_token_id, num_fewshot=1, @@ -714,16 +1386,18 @@ def test_code_eval_split_batch(dataset_uri, tmp_path): pytest.importorskip('datasets') local_data = os.path.join(os.path.dirname(__file__), 'local_data') dataset_uri = f'{local_data}/{dataset_uri}' - tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neox-20b') + transformers = pytest.importorskip('transformers') + tokenizer = transformers.AutoTokenizer.from_pretrained( + 'EleutherAI/gpt-neox-20b') # type: ignore reportUnboundVariable tmp_path_to_broadcast = 
str(os.path.abspath(tmp_path)) gathered_paths = dist.all_gather_object(tmp_path_to_broadcast) dl = get_icl_task_dataloader( 'code_evaluation', - dataset_uri, - tokenizer, - 8, - max_seq_len=64, + dataset_uri=dataset_uri, + tokenizer=tokenizer, + batch_size=8, + max_seq_len=1024, pad_tok_id=tokenizer.eos_token_id, num_fewshot=2, prompt_string='', @@ -755,7 +1429,6 @@ def test_code_eval_split_batch(dataset_uri, tmp_path): 'labels': str, 'prompts': str, 'tests': str, - 'canonical_solutions': str, 'entry_points': str, 'test_inputs': list, 'test_outputs': list, @@ -785,20 +1458,22 @@ def test_code_eval_sentpiece_dataloader(dataset_uri, tmp_path, num_fewshot, prom local_data = os.path.join(os.path.dirname(__file__), 'local_data') - tokenizer = AutoTokenizer.from_pretrained('huggyllama/llama-7b') + transformers = pytest.importorskip('transformers') + tokenizer = transformers.AutoTokenizer.from_pretrained('huggyllama/llama-7b') # type: ignore reportUnboundVariable dataset_uri = f'{local_data}/{dataset_uri}' batch_size = 4 seqlen = 2048 dl = get_icl_task_dataloader('code_evaluation', - dataset_uri, - tokenizer, - batch_size, + dataset_uri=dataset_uri, + tokenizer=tokenizer, + batch_size=batch_size, max_seq_len=seqlen, pad_tok_id=tokenizer.eos_token_id, num_fewshot=num_fewshot, prompt_string=prompt_string, example_delimiter='\n', + continuation_delimiter='', question_prelimiter='Code start: \n', destination_path=str(tmp_path / f'icl_{num_fewshot}.jsonl'), generations_per_sample=generations_per_sample) @@ -850,20 +1525,22 @@ def test_code_eval_test_cases(dataset_uri, tmp_path): local_data = os.path.join(os.path.dirname(__file__), 'local_data') - tokenizer = AutoTokenizer.from_pretrained('huggyllama/llama-7b') + transformers = pytest.importorskip('transformers') + tokenizer = transformers.AutoTokenizer.from_pretrained('huggyllama/llama-7b') # type: ignore reportUnboundVariable dataset_uri = f'{local_data}/{dataset_uri}' batch_size = 4 seqlen = 512 dl = get_icl_task_dataloader('code_evaluation', - dataset_uri, - tokenizer, - batch_size, + dataset_uri=dataset_uri, + tokenizer=tokenizer, + batch_size=batch_size, max_seq_len=seqlen, pad_tok_id=tokenizer.eos_token_id, num_fewshot=0, prompt_string='', example_delimiter='\n', + continuation_delimiter='', question_prelimiter='Code start: \n', destination_path=str(tmp_path / f'icl_.jsonl'), generations_per_sample=1) @@ -883,9 +1560,8 @@ def test_code_eval_test_cases(dataset_uri, tmp_path): assert any(item[0] != tokenizer.eos_token_id for item in batch['input_ids']) # longest should be pushed left mod = types.ModuleType('test_module') - for prompt, solution, inputs, outputs, entry_point in zip(batch['prompts'], batch['canonical_solutions'], - batch['test_inputs'], batch['test_outputs'], - batch['entry_points']): + for prompt, solution, inputs, outputs, entry_point in zip(batch['prompts'], batch['labels'], batch['test_inputs'], + batch['test_outputs'], batch['entry_points']): exec(prompt + solution, mod.__dict__) for test_input, test_output in zip(inputs, outputs): result = mod.__dict__[entry_point](*eval(test_input)) @@ -898,21 +1574,23 @@ def test_code_eval_pass_at_k_validity(dataset_uri, tmp_path): local_data = os.path.join(os.path.dirname(__file__), 'local_data') - tokenizer = AutoTokenizer.from_pretrained('huggyllama/llama-7b') + transformers = pytest.importorskip('transformers') + tokenizer = transformers.AutoTokenizer.from_pretrained('huggyllama/llama-7b') # type: ignore reportUnboundVariable dataset_uri = f'{local_data}/{dataset_uri}' batch_size = 2 
seqlen = 64 with pytest.raises(ValueError, match=r'.* pass_at_k .*'): get_icl_task_dataloader('code_evaluation', - dataset_uri, - tokenizer, - batch_size, + dataset_uri=dataset_uri, + tokenizer=tokenizer, + batch_size=batch_size, max_seq_len=seqlen, pad_tok_id=tokenizer.eos_token_id, num_fewshot=0, prompt_string='', example_delimiter='\n', + continuation_delimiter='', question_prelimiter='Code start: \n', destination_path=str(tmp_path / f'icl_.jsonl'), pass_at_k=10, @@ -928,23 +1606,29 @@ def test_code_eval_task_dataloader(dataset_uri, tmp_path, num_fewshot, prompt_st local_data = os.path.join(os.path.dirname(__file__), 'local_data') - tokenizer = AutoTokenizer.from_pretrained('mosaicml/mpt-7b') + transformers = pytest.importorskip('transformers') + tokenizer = transformers.AutoTokenizer.from_pretrained('mosaicml/mpt-7b') # type: ignore reportUnboundVariable dataset_uri = f'{local_data}/{dataset_uri}' batch_size = 4 seqlen = 2048 dl = get_icl_task_dataloader('code_evaluation', - dataset_uri, - tokenizer, - batch_size, + dataset_uri=dataset_uri, + tokenizer=tokenizer, + batch_size=batch_size, max_seq_len=seqlen, pad_tok_id=tokenizer.eos_token_id, num_fewshot=num_fewshot, prompt_string=prompt_string, example_delimiter='\n', + continuation_delimiter='', question_prelimiter='Code start: \n', destination_path=str(tmp_path / f'icl_{num_fewshot}.jsonl'), - generations_per_sample=generations_per_sample) + generations_per_sample=generations_per_sample, + generation_kwargs={ + 'temperature': .9, + 'top_k': 40 + }) assert isinstance(dl, DataSpec) assert isinstance(dl.dataloader, DataLoader) # pyright @@ -987,6 +1671,59 @@ def test_code_eval_task_dataloader(dataset_uri, tmp_path, num_fewshot, prompt_st ) +@pytest.mark.parametrize('dataset_uri', ['human_eval_small.jsonl']) +@pytest.mark.parametrize('num_fewshot', [0, 1]) +def test_eval_split_batch(tiny_opt_tokenizer, dataset_uri, num_fewshot, tmp_path): + pytest.importorskip('datasets') + + local_data = os.path.join(os.path.dirname(__file__), 'local_data') + transformers = pytest.importorskip('transformers') + tokenizer = transformers.AutoTokenizer.from_pretrained('mosaicml/mpt-7b') # type: ignore reportUnboundVariable + dataset_uri = f'{local_data}/{dataset_uri}' + batch_size = 4 + seqlen = 512 + + dl = get_icl_task_dataloader('code_evaluation', + dataset_uri=dataset_uri, + tokenizer=tokenizer, + batch_size=batch_size, + max_seq_len=seqlen, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=num_fewshot, + prompt_string='', + example_delimiter='\n', + continuation_delimiter='', + question_prelimiter='Code start: \n', + destination_path=str(tmp_path / f'icl_{num_fewshot}.jsonl'), + generations_per_sample=1, + generation_kwargs={ + 'temperature': .9, + 'top_k': 40 + }) + assert isinstance(dl, DataSpec) + assert isinstance(dl.dataloader, DataLoader) # pyright + batch = next(dl.dataloader._get_iterator()) + microbatch_size = 1 + microbatches = dl.split_batch(batch, microbatch_size) + assert len(microbatches) == 4 + for microbatch in microbatches: + assert dl.get_num_samples_in_batch(microbatch) == 1 + assert 'input_ids' in microbatch + # TODO: what should this be? 
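+ # (Prompts in this dataset appear to be padded to the longest prompt rather than to max_seq_len, per test_code_adjust_padding above, so the expected shape here is unclear.)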
+ # assert tuple(microbatch['input_ids'].shape) == (microbatch_size, seqlen) + assert 'attention_mask' in microbatch + # assert tuple(microbatch['attention_mask'].shape) == (microbatch_size, seqlen) + assert isinstance(microbatch['generation_kwargs'], dict) + assert microbatch['generation_kwargs']['temperature'] == .9 + assert microbatch['generation_kwargs']['top_k'] == 40 + assert microbatch['generation_kwargs']['pad_token_id'] == 0 + assert microbatch['generation_kwargs']['num_beams'] == 1 + assert microbatch['generation_kwargs']['num_return_sequences'] == 1 + assert microbatch['generation_kwargs']['do_sample'] == True + assert microbatch['generation_kwargs']['use_cache'] == True + assert microbatch['generation_kwargs']['eos_token_id'] == 0 + + @pytest.mark.parametrize('dataset_uri', ['lambada_small.jsonl']) @pytest.mark.parametrize('num_fewshot', [0, 5]) @device('gpu') @@ -996,11 +1733,12 @@ def test_lm_task_evaluation(device, dataset_uri, num_fewshot, tiny_gpt2_tokenize local_data = os.path.join(os.path.dirname(__file__), 'local_data') dataset_uri = f'{local_data}/{dataset_uri}' tokenizer = tiny_gpt2_tokenizer + batch_size = 2 dl = get_icl_task_dataloader( 'language_modeling', - dataset_uri, - tokenizer, - 2, + dataset_uri=dataset_uri, + tokenizer=tokenizer, + batch_size=batch_size, max_seq_len=2048, pad_tok_id=tokenizer.eos_token_id, num_fewshot=num_fewshot, @@ -1012,6 +1750,7 @@ def test_lm_task_evaluation(device, dataset_uri, num_fewshot, tiny_gpt2_tokenize evaluator = Evaluator(label='lambada', dataloader=dl, metric_names=['InContextLearningLMAccuracy']) + transformers = pytest.importorskip('transformers') config = transformers.AutoConfig.from_pretrained('EleutherAI/gpt-neo-125M') model = transformers.AutoModelForCausalLM.from_config(config) model = HuggingFaceModel( @@ -1027,8 +1766,8 @@ def test_lm_task_evaluation(device, dataset_uri, num_fewshot, tiny_gpt2_tokenize assert in_memory_logger.data['metrics/lambada/InContextLearningLMAccuracy'][0][1].item() == 0 -@pytest.mark.parametrize('dataset_uri', ['winograd_small.jsonl']) @pytest.mark.parametrize('num_fewshot', [0, 5]) +@pytest.mark.parametrize('dataset_uri', ['winograd_small.jsonl']) @pytest.mark.filterwarnings(r'ignore:Cannot split .* of length.*:UserWarning') def test_schema_task_evaluation(num_fewshot, dataset_uri, tiny_gpt2_tokenizer, tmp_path, tiny_gpt2_model): pytest.importorskip('datasets') @@ -1036,12 +1775,13 @@ def test_schema_task_evaluation(num_fewshot, dataset_uri, tiny_gpt2_tokenizer, t local_data = os.path.join(os.path.dirname(__file__), 'local_data') dataset_uri = f'{local_data}/{dataset_uri}' tokenizer = tiny_gpt2_tokenizer + batch_size = 8 dl = get_icl_task_dataloader( 'schema', - dataset_uri, - tokenizer, - 8, - max_seq_len=64, + dataset_uri=dataset_uri, + tokenizer=tokenizer, + batch_size=batch_size, + max_seq_len=1024, pad_tok_id=tokenizer.eos_token_id, num_fewshot=num_fewshot, prompt_string='', @@ -1082,13 +1822,16 @@ def test_mc_task_evaluation_subcategories(device, world_size, dataset_uri, num_f local_data = os.path.join(os.path.dirname(__file__), 'local_data') dataset_uri = f'{local_data}/{dataset_uri}' tokenizer = tiny_gpt2_tokenizer + batch_size = 8 + max_seq_len = 64 tmp_path_to_broadcast = str(os.path.abspath(tmp_path)) gathered_paths = dist.all_gather_object(tmp_path_to_broadcast) + reproducibility.seed_all(1234) dls = get_icl_task_dataloader('multiple_choice', - dataset_uri, - tokenizer, - 8, - max_seq_len=64, + dataset_uri=dataset_uri, + tokenizer=tokenizer, + batch_size=batch_size, + 
max_seq_len=max_seq_len, pad_tok_id=tokenizer.eos_token_id, num_fewshot=num_fewshot, prompt_string='', @@ -1121,29 +1864,35 @@ def test_mc_task_evaluation_subcategories(device, world_size, dataset_uri, num_f @pytest.mark.parametrize('dataset_uri', ['piqa_small.jsonl', 'hellaswag_small.jsonl']) -@device('gpu') @pytest.mark.parametrize('num_fewshot', [0, 5]) -def test_mc_task_evaluation(device, num_fewshot, dataset_uri, tiny_gpt2_tokenizer, tmp_path, tiny_gpt2_model): +@pytest.mark.filterwarnings(r'ignore:Cannot split .* of length.*:UserWarning') +@device('gpu') +@world_size(1, 2) +def test_mc_task_evaluation(device, world_size, num_fewshot, dataset_uri, tiny_gpt2_tokenizer, tmp_path, + tiny_gpt2_model): pytest.importorskip('datasets') in_memory_logger = InMemoryLogger() # track the logged metrics in the in_memory_logger local_data = os.path.join(os.path.dirname(__file__), 'local_data') dataset_uri = f'{local_data}/{dataset_uri}' tokenizer = tiny_gpt2_tokenizer + batch_size = 8 + tmp_path_to_broadcast = str(os.path.abspath(tmp_path)) + gathered_paths = dist.all_gather_object(tmp_path_to_broadcast) # seed because the fewshot selection is currently unseeded reproducibility.seed_all(1234) dl = get_icl_task_dataloader( 'multiple_choice', - dataset_uri, - tokenizer, - 8, + dataset_uri=dataset_uri, + tokenizer=tokenizer, + batch_size=batch_size, max_seq_len=64, pad_tok_id=tokenizer.eos_token_id, num_fewshot=num_fewshot, prompt_string='', example_delimiter='\n', continuation_delimiter=': ', - destination_path=str(tmp_path / 'icl.jsonl'), + destination_path=str(Path(gathered_paths[0]) / 'icl.jsonl'), ) evaluator = Evaluator(label='mc', dataloader=dl, metric_names=['InContextLearningMultipleChoiceAccuracy']) @@ -1163,14 +1912,17 @@ def test_mc_task_evaluation(device, num_fewshot, dataset_uri, tiny_gpt2_tokenize with open(dataset_uri) as f: for _ in f: num_samples += 1 - assert trainer.state.eval_metrics['mc']['InContextLearningMultipleChoiceAccuracy'].total == num_samples + total = trainer.state.eval_metrics['mc']['InContextLearningMultipleChoiceAccuracy'].total + dist.all_reduce(total) # type: ignore + assert total.item() == num_samples # type: ignore +@pytest.mark.parametrize('num_fewshot', [0, 5]) @pytest.mark.parametrize('dataset_uri', ['triviaqa_small.jsonl']) +@pytest.mark.filterwarnings(r'ignore:.*The dataloader_len \(2\) is greater than the length.*:UserWarning') +@pytest.mark.filterwarnings(r'ignore:Cannot split .* of length.*:UserWarning') @device('gpu') @world_size(1, 2) -@pytest.mark.parametrize('num_fewshot', [0, 5]) -@pytest.mark.filterwarnings(r'ignore:.*The dataloader_len \(2\) is greater than the length.*:UserWarning') def test_qa_task_evaluation_opt_tokenizer(device, world_size, tiny_opt_tokenizer, tiny_opt_model, num_fewshot, dataset_uri, tmp_path): pytest.importorskip('datasets') @@ -1179,14 +1931,15 @@ def test_qa_task_evaluation_opt_tokenizer(device, world_size, tiny_opt_tokenizer dataset_uri = f'{local_data}/{dataset_uri}' tokenizer = tiny_opt_tokenizer + batch_size = 4 tmp_path_to_broadcast = str(os.path.abspath(tmp_path)) gathered_paths = dist.all_gather_object(tmp_path_to_broadcast) dl = get_icl_task_dataloader( 'question_answering', - dataset_uri, - tokenizer, - 2, - max_seq_len=64, + dataset_uri=dataset_uri, + tokenizer=tokenizer, + batch_size=batch_size, + max_seq_len=1024, pad_tok_id=tokenizer.eos_token_id, num_fewshot=num_fewshot, prompt_string='', @@ -1210,11 +1963,12 @@ def test_qa_task_evaluation_opt_tokenizer(device, world_size, tiny_opt_tokenizer assert 
in_memory_logger.data['metrics/triviaqa/InContextLearningQAAccuracy'][0][1].item() == 0 +@pytest.mark.parametrize('num_fewshot', [5]) @pytest.mark.parametrize('dataset_uri', ['gsm8k_small.jsonl']) @device('gpu') @world_size(1, 2) -@pytest.mark.parametrize('num_fewshot', [5]) @pytest.mark.filterwarnings(r'ignore:.*The dataloader_len \(2\) is greater than the length.*:UserWarning') +@pytest.mark.filterwarnings(r'ignore:Cannot split .* of length.*:UserWarning') def test_qa_task_evaluation_with_cot_opt_tokenizer(device, world_size, tiny_opt_tokenizer, tiny_opt_model, num_fewshot, dataset_uri, tmp_path): pytest.importorskip('datasets') @@ -1223,14 +1977,15 @@ def test_qa_task_evaluation_with_cot_opt_tokenizer(device, world_size, tiny_opt_ dataset_uri = f'{local_data}/{dataset_uri}' tokenizer = tiny_opt_tokenizer + batch_size = 4 tmp_path_to_broadcast = str(os.path.abspath(tmp_path)) gathered_paths = dist.all_gather_object(tmp_path_to_broadcast) dl = get_icl_task_dataloader( 'question_answering', - dataset_uri, - tokenizer, - 2, - max_seq_len=256, + dataset_uri=dataset_uri, + tokenizer=tokenizer, + batch_size=batch_size, + max_seq_len=1024, pad_tok_id=tokenizer.eos_token_id, num_fewshot=num_fewshot, prompt_string='', @@ -1256,9 +2011,9 @@ def test_qa_task_evaluation_with_cot_opt_tokenizer(device, world_size, tiny_opt_ @pytest.mark.parametrize('dataset_uri', ['triviaqa_small.jsonl']) +@pytest.mark.parametrize('num_fewshot', [0, 5]) @device('gpu') @world_size(1, 2) -@pytest.mark.parametrize('num_fewshot', [0, 5]) @pytest.mark.filterwarnings(r'ignore:.*The dataloader_len \(2\) is greater than the length.*:UserWarning') def test_qa_task_evaluation(device, world_size, num_fewshot, dataset_uri, tiny_gpt2_tokenizer, tiny_gpt2_model, tmp_path): @@ -1267,14 +2022,15 @@ def test_qa_task_evaluation(device, world_size, num_fewshot, dataset_uri, tiny_g local_data = os.path.join(os.path.dirname(__file__), 'local_data') dataset_uri = f'{local_data}/{dataset_uri}' tokenizer = tiny_gpt2_tokenizer + batch_size = 2 tmp_path_to_broadcast = str(os.path.abspath(tmp_path)) gathered_paths = dist.all_gather_object(tmp_path_to_broadcast) dl = get_icl_task_dataloader( 'question_answering', - dataset_uri, - tokenizer, - 2, - max_seq_len=64, + dataset_uri=dataset_uri, + tokenizer=tokenizer, + batch_size=batch_size, + max_seq_len=1024, pad_tok_id=tokenizer.eos_token_id, num_fewshot=num_fewshot, prompt_string='', @@ -1300,10 +2056,10 @@ def test_qa_task_evaluation(device, world_size, num_fewshot, dataset_uri, tiny_g @pytest.mark.parametrize('dataset_uri', ['gsm8k_small.jsonl']) -@device('gpu') -@world_size(1, 2) @pytest.mark.parametrize('num_fewshot', [5]) @pytest.mark.filterwarnings(r'ignore:.*The dataloader_len \(2\) is greater than the length.*:UserWarning') +@device('gpu') +@world_size(1, 2) def test_qa_task_with_cot_evaluation(device, world_size, num_fewshot, dataset_uri, tiny_gpt2_tokenizer, tiny_gpt2_model, tmp_path): pytest.importorskip('datasets') @@ -1311,14 +2067,15 @@ def test_qa_task_with_cot_evaluation(device, world_size, num_fewshot, dataset_ur local_data = os.path.join(os.path.dirname(__file__), 'local_data') dataset_uri = f'{local_data}/{dataset_uri}' tokenizer = tiny_gpt2_tokenizer + batch_size = 2 tmp_path_to_broadcast = str(os.path.abspath(tmp_path)) gathered_paths = dist.all_gather_object(tmp_path_to_broadcast) dl = get_icl_task_dataloader( 'question_answering', - dataset_uri, - tokenizer, - 2, - max_seq_len=256, + dataset_uri=dataset_uri, + tokenizer=tokenizer, + batch_size=batch_size, + 
+        max_seq_len=1024,
         pad_tok_id=tokenizer.eos_token_id,
         num_fewshot=num_fewshot,
         prompt_string='',
@@ -1357,10 +2114,10 @@ def test_code_eval_requires_valid_envvar(monkeypatch):
 @pytest.mark.parametrize('dataset_uri', ['human_eval_small.jsonl'])
+@pytest.mark.parametrize('num_fewshot', [0])
+@pytest.mark.parametrize('generations_per_sample', range(1, 3))
 @device('gpu')
 @world_size(1, 2)
-@pytest.mark.parametrize('num_fewshot', [0])
-@pytest.mark.parametrize('generations_per_sample', [1, 2])
 @pytest.mark.filterwarnings(r'ignore:.*The dataloader_len \(2\) is greater than the length.*:UserWarning')
 def test_code_eval_microbatching(monkeypatch, device, world_size, tiny_opt_tokenizer, tiny_opt_model, num_fewshot,
                                  dataset_uri, tmp_path, generations_per_sample):
@@ -1370,15 +2127,16 @@ def test_code_eval_microbatching(monkeypatch, device, world_size, tiny_opt_token
     local_data = os.path.join(os.path.dirname(__file__), 'local_data')
     dataset_uri = f'{local_data}/{dataset_uri}'
     tokenizer = tiny_opt_tokenizer
+    batch_size = 4
     tmp_path_to_broadcast = str(os.path.abspath(tmp_path))
     gathered_paths = dist.all_gather_object(tmp_path_to_broadcast)
     dl = get_icl_task_dataloader(
         'code_evaluation',
-        dataset_uri,
-        tokenizer,
-        2,
-        max_seq_len=256,
+        dataset_uri=dataset_uri,
+        tokenizer=tokenizer,
+        batch_size=batch_size,
+        max_seq_len=150,
         pad_tok_id=tokenizer.eos_token_id,
         num_fewshot=num_fewshot,
         prompt_string='',
@@ -1408,10 +2166,10 @@ def test_code_eval_microbatching(monkeypatch, device, world_size, tiny_opt_token
 @pytest.mark.parametrize('dataset_uri', ['human_eval_small.jsonl'])
+@pytest.mark.parametrize('num_fewshot', [0])
+@pytest.mark.parametrize('generations_per_sample', range(1, 3))
 @device('gpu')
 @world_size(1, 2)
-@pytest.mark.parametrize('num_fewshot', [0])
-@pytest.mark.parametrize('generations_per_sample', [1, 2])
 @pytest.mark.filterwarnings(r'ignore:.*The dataloader_len \(2\) is greater than the length.*:UserWarning')
 def test_code_eval_sentpiece_evaluation(monkeypatch, device, world_size, num_fewshot, dataset_uri, tiny_t5_tokenizer,
                                         tiny_t5_model, tmp_path, generations_per_sample):
@@ -1421,14 +2179,15 @@ def test_code_eval_sentpiece_evaluation(monkeypatch, device, world_size, num_few
     local_data = os.path.join(os.path.dirname(__file__), 'local_data')
     dataset_uri = f'{local_data}/{dataset_uri}'
     tokenizer = tiny_t5_tokenizer
+    batch_size = 2
     tmp_path_to_broadcast = str(os.path.abspath(tmp_path))
     gathered_paths = dist.all_gather_object(tmp_path_to_broadcast)
     dl = get_icl_task_dataloader(
         'code_evaluation',
-        dataset_uri,
-        tokenizer,
-        2,
-        max_seq_len=256,
+        dataset_uri=dataset_uri,
+        tokenizer=tokenizer,
+        batch_size=batch_size,
+        max_seq_len=175,
         pad_tok_id=tokenizer.eos_token_id,
         num_fewshot=num_fewshot,
         prompt_string='',
@@ -1455,11 +2214,11 @@ def test_code_eval_sentpiece_evaluation(monkeypatch, device, world_size, num_few
 @pytest.mark.parametrize('dataset_uri', ['human_eval_small.jsonl'])
-@device('gpu')
-@world_size(1, 2)
 @pytest.mark.parametrize('num_fewshot', [0, 2])
 @pytest.mark.parametrize('generations_per_sample', [1])
 @pytest.mark.filterwarnings(r'ignore: Input length of input_ids is')
+@device('gpu')
+@world_size(1, 2)
 @pytest.mark.filterwarnings(r'ignore:.*The dataloader_len \(2\) is greater than the length.*:UserWarning')
 def test_code_eval_task_evaluation(monkeypatch, device, world_size, num_fewshot, dataset_uri, tiny_gpt2_tokenizer,
                                    tiny_gpt2_model, tmp_path, generations_per_sample):
@@ -1469,13 +2228,14 @@ def test_code_eval_task_evaluation(monkeypatch, device, world_size, num_fewshot,
     local_data = os.path.join(os.path.dirname(__file__), 'local_data')
     dataset_uri = f'{local_data}/{dataset_uri}'
     tokenizer = tiny_gpt2_tokenizer
+    batch_size = 2
     tmp_path_to_broadcast = str(os.path.abspath(tmp_path))
     gathered_paths = dist.all_gather_object(tmp_path_to_broadcast)
     dl = get_icl_task_dataloader(
         'code_evaluation',
-        dataset_uri,
-        tokenizer,
-        2,
+        dataset_uri=dataset_uri,
+        tokenizer=tokenizer,
+        batch_size=batch_size,
         max_seq_len=64 * num_fewshot,
         pad_tok_id=tokenizer.eos_token_id,
         num_fewshot=num_fewshot,
@@ -1513,9 +2273,9 @@ def test_lm_spacing_dataloader(dataset_uri, tiny_gpt2_tokenizer, tmp_path):
     batch_size = 2
     seqlen = 512
     dl = get_icl_task_dataloader('language_modeling',
-                                 dataset_uri,
-                                 tokenizer,
-                                 batch_size,
+                                 dataset_uri=dataset_uri,
+                                 tokenizer=tokenizer,
+                                 batch_size=batch_size,
                                  max_seq_len=seqlen,
                                  pad_tok_id=tokenizer.eos_token_id,
                                  num_fewshot=1,
@@ -1539,3 +2299,112 @@ def test_lm_spacing_dataloader(dataset_uri, tiny_gpt2_tokenizer, tmp_path):
     assert first_batch_without_last_word.count(' UNIQUE ') == 1
     assert second_batch_without_last_word.count(' UNIQUE ') == 1
+
+
+@pytest.mark.parametrize('dataset_uri', ['hf://mosaicml/test_dataset'])
+@pytest.mark.parametrize('num_fewshot', [0, 1])
+@pytest.mark.parametrize('prompt_string', ['Complete the voiceline: ', ''])
+@pytest.mark.parametrize('hf_loading_vars', [{
+    'split': 'test',
+    'name': 'juggernaut',
+}])
+@pytest.mark.parametrize('hf_parsing_map', [None, {'context': ['context'], 'continuation': ['continuation']}])
+@pytest.mark.filterwarnings(
+    r'ignore:The repository for mosaicml/test_dataset contains custom code which must*:FutureWarning')
+def test_hf_dataloading_lm_dataloader(dataset_uri, tiny_gpt2_tokenizer, tmp_path, num_fewshot, prompt_string,
+                                      hf_loading_vars, hf_parsing_map):
+    pytest.importorskip('datasets')
+
+    tokenizer = tiny_gpt2_tokenizer
+    batch_size = 2
+    seqlen = 2048
+    dl = get_icl_task_dataloader('language_modeling',
+                                 dataset_uri=dataset_uri,
+                                 tokenizer=tokenizer,
+                                 batch_size=batch_size,
+                                 max_seq_len=seqlen,
+                                 pad_tok_id=tokenizer.eos_token_id,
+                                 num_fewshot=0,
+                                 prompt_string='',
+                                 example_delimiter='\n',
+                                 continuation_delimiter=' ',
+                                 destination_path=str(tmp_path / 'test_dataset_lm_juggernaut.jsonl'),
+                                 hf_loading_vars=hf_loading_vars,
+                                 hf_parsing_map=hf_parsing_map)
+    assert isinstance(dl, DataSpec)
+    assert isinstance(dl.dataloader, DataLoader)  # pyright
+    batch = next(dl.dataloader._get_iterator())
+
+    assert 'input_ids' in batch
+    assert tuple(batch['input_ids'].shape) == (batch_size, seqlen)
+    assert 'attention_mask' in batch
+    assert tuple(batch['attention_mask'].shape) == (batch_size, seqlen)
+    assert 'continuation_indices' in batch
+    assert isinstance(batch['continuation_indices'], list) and len(batch['continuation_indices']) == batch_size
+    assert 'mode' in batch
+    assert batch['mode'] == 'icl_task'
+    min_idx = min(batch['continuation_indices'][0]).item()
+    max_idx = max(batch['continuation_indices'][0]).item()
+    assert tokenizer.decode(batch['input_ids'][0][min_idx:max_idx + 1]) == ' and me.'
+
+    decoded_batch = [tokenizer.decode(row[row != tokenizer.eos_token_id]) for row in batch['input_ids']]
+    assert decoded_batch[0] == "Looks like it's just you and me."
+    assert decoded_batch[1] == "There's a fine line between bravery and stupidity."
+
+
+@pytest.mark.parametrize('dataset_uri', ['hf://mosaicml/test_dataset'])
+@pytest.mark.parametrize('num_fewshot', [0, 1])
+@pytest.mark.parametrize('prompt_string', ['What spell does this invoke? ', ''])
+@pytest.mark.parametrize('hf_loading_vars', [{
+    'split': 'test',
+    'name': 'invoker',
+}])
+@pytest.mark.parametrize('hf_parsing_map', [{'context': ['quas', 'wex', 'exort'], 'answer': ['spell']}])
+@pytest.mark.filterwarnings(
+    r'ignore:The repository for mosaicml/test_dataset contains custom code which must*:FutureWarning')
+def test_hf_dataloading_custom_parsing(dataset_uri, tiny_gpt2_tokenizer, tmp_path, num_fewshot, prompt_string,
+                                       hf_loading_vars, hf_parsing_map):
+    pytest.importorskip('datasets')
+
+    tokenizer = tiny_gpt2_tokenizer
+    batch_size = 2
+    seqlen = 2048
+
+    # empirical number from the small test dataset
+    maximum_answer_length = 4
+
+    dl = get_icl_task_dataloader('question_answering',
+                                 dataset_uri=dataset_uri,
+                                 tokenizer=tokenizer,
+                                 batch_size=batch_size,
+                                 max_seq_len=seqlen,
+                                 pad_tok_id=tokenizer.eos_token_id,
+                                 num_fewshot=num_fewshot,
+                                 prompt_string=prompt_string,
+                                 example_delimiter='\n',
+                                 question_prelimiter='Orbs: ',
+                                 continuation_delimiter='\nSpell:',
+                                 destination_path=str(tmp_path / 'test_dataset_lm_juggernaut.jsonl'),
+                                 hf_loading_vars=hf_loading_vars,
+                                 hf_parsing_map=hf_parsing_map)
+    assert isinstance(dl, DataSpec)
+    assert isinstance(dl.dataloader, DataLoader)  # pyright
+    batch = next(dl.dataloader._get_iterator())
+
+    assert tuple(batch['input_ids'].shape) == (batch_size, seqlen - maximum_answer_length)
+    assert tuple(batch['attention_mask'].shape) == (batch_size, seqlen - maximum_answer_length)
+    assert batch['mode'] == 'generate'
+    # the maximum generation length from the small test data
+    assert batch['generation_length'] == maximum_answer_length
+    assert all(item[0] == tokenizer.eos_token_id for item in batch['input_ids'])
+
+    decoded_batch = tokenizer.batch_decode(batch['input_ids'])
+    assert all(item.count('Orbs: ') == num_fewshot + 1 for item in decoded_batch)
+    assert all(item.count('\nSpell:') == num_fewshot + 1 for item in decoded_batch)
+
+    if len(prompt_string) > 0:
+        assert all(item.count('What spell does this invoke? ') == 1 for item in decoded_batch)
+    assert all(
+        set(found) == set(expected) for found, expected in zip(batch['labels'], [['defeaning blast'], ['cold snap']]))
+    assert decoded_batch[0].endswith('Orbs: quas wex exort\nSpell:')
+    assert decoded_batch[1].endswith('Orbs: quas quas quas\nSpell:')