Commit

Migrate ICL classes to foundry (#936)
* start

* still need to migrate fixtures

* wip onboarding tests

* still workin'

* still wip

* maybe done; test out on mcli now

* mcli

* remove calibration error

* migration

* migration

* full migration

* precommit

* fix

* fix pytests

* refactor QA

* update

* restore

* add

* fix

* wip

* update readme

* final pyright

* done

* pass prelimiter into ALL the ICL task datasets

* allow QA task name still for backward compatibility

* fix

* fix test

* add generation length

* remove max_new_tokens

* fix cpu tests

* try and fix lm eval test

* temp disable lm task eval test

* fix test?

* fix test

* finish

* fix

* Update scripts/eval/README.md

Co-authored-by: Daniel King <[email protected]>

* fix comments

* fix bug with seq len

* restore mcli

* merge

* fix builder

* add deprecation warning

* add deprecation warning

* merge

* merge

* add logging necessities to nlp.py

* add attention_mask test update

* fix generation_length in tests

* fix bug

* restore yamls

* fix typos

* add deprecation warning for code

* pyright wip

* fix pyright

* fix pyright error again

* fix pyright

* fix pyright

* update version

---------

Co-authored-by: Eitan Turok <[email protected]>
Co-authored-by: Max Marion <[email protected]>
Co-authored-by: Daniel King <[email protected]>
Co-authored-by: Max Marion <[email protected]>
5 people authored Apr 12, 2024
1 parent 6257e5b commit 3729ba3
Showing 29 changed files with 6,107 additions and 47 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -2,6 +2,7 @@
my-copy-c4*/
my-copy-arxiv*/
*.jsonl*
!tests/eval/local_data/*.jsonl

# WandB
wandb/
31 changes: 31 additions & 0 deletions llmfoundry/eval/datasets/__init__.py
@@ -0,0 +1,31 @@
# Copyright 2024 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

"""Natively supported in-context learning evaluation datasets."""

from llmfoundry.eval.datasets.in_context_learning_evaluation import (
    InContextLearningCodeEvalDataset, InContextLearningDataset,
    InContextLearningGenerationTaskWithAnswersDataset,
    InContextLearningLMTaskDataset, InContextLearningMultipleChoiceTaskDataset,
    InContextLearningSchemaTaskDataset, get_icl_task_dataloader)
from llmfoundry.eval.datasets.utils import (get_continuation_span,
                                            get_fewshot_sample_idxs,
                                            make_padded_input, strip_data,
                                            tokenizer_needs_prefix_space,
                                            trim_context)

__all__ = [
    'InContextLearningDataset',
    'InContextLearningGenerationTaskWithAnswersDataset',
    'InContextLearningLMTaskDataset',
    'InContextLearningCodeEvalDataset',
    'InContextLearningMultipleChoiceTaskDataset',
    'InContextLearningSchemaTaskDataset',
    'get_icl_task_dataloader',
    'strip_data',
    'tokenizer_needs_prefix_space',
    'trim_context',
    'get_continuation_span',
    'get_fewshot_sample_idxs',
    'make_padded_input',
]
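
A hedged usage sketch of the migrated entry point: the argument names follow the Composer implementation these classes were ported from, and the tokenizer and dataset path are placeholder assumptions, not part of this diff.

from transformers import AutoTokenizer

from llmfoundry.eval.datasets import get_icl_task_dataloader

tokenizer = AutoTokenizer.from_pretrained('gpt2')  # placeholder tokenizer
dl = get_icl_task_dataloader(
    icl_task_type='language_modeling',
    dataset_uri='tests/eval/local_data/lambada_small.jsonl',  # placeholder path
    tokenizer=tokenizer,
    batch_size=8,
    max_seq_len=1024,
    pad_tok_id=tokenizer.eos_token_id,
    num_fewshot=0,
    prompt_string='',
    example_delimiter='\n',
    continuation_delimiter='',
)
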
1,791 changes: 1,791 additions & 0 deletions llmfoundry/eval/datasets/in_context_learning_evaluation.py

Large diffs are not rendered by default.

277 changes: 277 additions & 0 deletions llmfoundry/eval/datasets/utils.py
@@ -0,0 +1,277 @@
# Copyright 2024 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

"""Utility and helper functions for datasets."""
from __future__ import annotations

import logging
import random
from typing import Any, Dict, List, Optional, Set

import torch
import transformers

__all__ = [
    'MultiTokenEOSCriteria',
]

log = logging.getLogger(__name__)


def strip_data(example: Dict) -> Dict:
    """Remove whitespace from the beginning and end of string values in a
    dictionary.

    Args:
        example: Dictionary to be stripped

    Returns:
        dict: The same dictionary with .strip() applied to any value in the dict that is a string
    """
    return {
        k: v.strip() if isinstance(v, str) else v for k, v in example.items()
    }
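# Illustrative usage sketch (hypothetical values): only string values are
# stripped; other types pass through unchanged.
#     strip_data({'context': ' Q: 2+2 ', 'answer': ' 4', 'idx': 7})
#     # -> {'context': 'Q: 2+2', 'answer': '4', 'idx': 7}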


def tokenizer_needs_prefix_space(
        tokenizer: transformers.PreTrainedTokenizerBase) -> bool:
    """Test for whether a prefix space is needed before the continuation.

    SentencePiece tokenization should not have a prefix space, but GPT-2 style
    BPE should.

    Args:
        tokenizer: Tokenizer to test

    Returns:
        bool: Whether or not the tokenizer needs a prefix space
    """
    test_tokens = tokenizer(' a', add_special_tokens=False)['input_ids']
    assert isinstance(test_tokens, list)
    return len(test_tokens) == 1
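# Illustrative usage sketch (assumes the gpt2 tokenizer is cached or
# downloadable): GPT-2 style BPE encodes ' a' as a single token, so it
# needs the prefix space.
#     tok = transformers.AutoTokenizer.from_pretrained('gpt2')
#     tokenizer_needs_prefix_space(tok)  # -> True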


def trim_context(context_enc: List, continuation_enc: List,
                 max_seq_len: int) -> List:
    """Trims a list of tokens down to `max_seq_len` if the length of the list
    plus the continuation is more than `max_seq_len`. It will always trim tokens
    from the left, i.e. tokens at the beginning of the context will be removed.

    Args:
        context_enc (list): List of tokens in the context
        continuation_enc (list): List of tokens in the continuation
        max_seq_len (int): Maximum length the model can ingest

    Returns:
        list: The encoded context trimmed from the left
    """
    if len(continuation_enc) + len(context_enc) > max_seq_len:
        context_max_subseq_len = max_seq_len - len(continuation_enc)

        if context_max_subseq_len < 0:
            # can't support continuations which are longer than the max seq len
            raise Exception(
                'Dataset included continuation longer than the max seq len')

        # keep only the rightmost tokens, trimming from the left
        context_enc = context_enc[-(context_max_subseq_len):]
    return context_enc
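# Illustrative usage sketch (hypothetical token ids): an 8-token context
# with a 2-token continuation in a 6-token window keeps only the last 4
# context tokens.
#     trim_context(list(range(8)), [100, 101], max_seq_len=6)  # -> [4, 5, 6, 7]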


def get_continuation_span(context_enc: List,
                          continuation_enc: List) -> torch.Tensor:
    """Gets the list of indices of the continuation tokens for language
    modeling or generation tasks.

    Args:
        context_enc (list): List of context tokens
        continuation_enc (list): List of continuation tokens

    Returns:
        torch.tensor: A tensor containing indices corresponding to continuation tokens
    """
    return torch.tensor(
        range(len(context_enc),
              len(context_enc) + len(continuation_enc)))
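# Illustrative usage sketch (hypothetical token ids): with 3 context tokens
# and 2 continuation tokens, the continuation occupies indices 3 and 4.
#     get_continuation_span([10, 11, 12], [13, 14])  # -> tensor([3, 4])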


def make_padded_input(context_enc: List,
                      continuation_enc: List,
                      max_seq_len: int,
                      pad_tok_id: int,
                      padding_side: str = 'right') -> torch.Tensor:
    """Concatenates an encoded context and continuation and pads the result
    out to `max_seq_len` with the padding token on the specified side.

    Args:
        context_enc (List): The encoded input to the model
        continuation_enc (List): The encoded desired output for the example
        max_seq_len (int): Maximum length sequences can be
        pad_tok_id (int): The token id we pad with
        padding_side (str): Which side to pad the context on. Can be 'right' or 'left'

    Returns:
        input (torch.tensor): The padded and encoded context
    """
    inp = torch.tensor(
        (context_enc + continuation_enc),
        dtype=torch.long,
    )
    (inp_len,) = inp.shape

    # Sometimes tokenizers that have neither a pad_tok_id nor an eos_tok_id will pass None in as the padding
    # token, which causes errors
    if not isinstance(pad_tok_id, int):
        raise ValueError(
            f'`pad_tok_id` must be an integer. Found {type(pad_tok_id)} instead'
        )
    # pad length from seq to padding_length
    if padding_side == 'right':
        inp = torch.cat(
            [
                inp,  # [seq]
                torch.LongTensor((max_seq_len - inp_len) * [pad_tok_id]),
            ],
            dim=0,
        )
    elif padding_side == 'left':
        inp = torch.cat(
            [
                torch.LongTensor((max_seq_len - inp_len) * [pad_tok_id]),
                inp,  # [seq]
            ],
            dim=0,
        )
    else:
        raise ValueError(
            f"Unknown padding_side {padding_side}. padding_side must be either 'left' or 'right'"
        )

    return inp
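# Illustrative usage sketch (hypothetical token ids): right-padding a
# 5-token example out to max_seq_len=8 with pad_tok_id=0.
#     make_padded_input([1, 2, 3], [4, 5], max_seq_len=8, pad_tok_id=0)
#     # -> tensor([1, 2, 3, 4, 5, 0, 0, 0])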


def convert_tokens_to_tensors(batch: Dict,
                              tokenize_labels: bool) -> Dict[str, Any]:
    """HF Datasets converts tensors into lists when we store them, and we
    don't want to use `type='torch'` because some content in the dataset, like
    generation args or single ints, should not be converted.

    Here, we convert those lists of tokens back into tensors in order to feed them into the model.

    Args:
        batch (dict): A dictionary of batched inputs
        tokenize_labels (bool): Whether or not the labels are tokenized (and need to be stacked)

    Returns:
        dict: The batch with torch tensors in the corresponding keys instead of lists of lists
    """
    batch['input_ids'] = torch.stack(list(map(torch.tensor,
                                              batch['input_ids'])))
    if tokenize_labels:
        batch['labels'] = torch.stack(list(map(torch.tensor, batch['labels'])))
        batch['continuation_indices'] = list(
            map(torch.tensor, batch['continuation_indices']))
    return batch
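# Illustrative usage sketch (hypothetical batch): token-id lists are
# re-stacked into a [batch, seq] tensor; non-token entries are left alone.
#     b = {'input_ids': [[1, 2], [3, 4]], 'mode': 'generate'}
#     convert_tokens_to_tensors(b, tokenize_labels=False)['input_ids'].shape
#     # -> torch.Size([2, 2])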


def get_fewshot_sample_idxs(dataset_size: int, num_fewshot: int,
                            example_idx: int, rng: random.Random) -> Set[int]:
    """Samples indices without replacement. If num_fewshot exceeds the number
    of unique examples in the dataset, then we will have fewer than num_fewshot examples in context.

    Args:
        dataset_size (int): Length of the dataset
        num_fewshot (int): Number of examples to prepend
        example_idx (int): Current example's index (excluded from fewshot choices)
        rng (random.Random): RNG for repeatable sample selection

    Returns:
        set: Indices of the examples chosen for fewshot selection
    """
    num_fewshot = min(dataset_size - 1, num_fewshot)
    fewshot_idxs = set(rng.sample(range(0, dataset_size), num_fewshot))

    if example_idx in fewshot_idxs:
        fewshot_idxs.remove(example_idx)
        if len(fewshot_idxs) >= dataset_size - 1:
            return fewshot_idxs

        replacement_sample = rng.choice(range(0, dataset_size))
        while replacement_sample in fewshot_idxs or replacement_sample == example_idx:
            replacement_sample = rng.choice(range(0, dataset_size))
        fewshot_idxs.add(replacement_sample)
    return fewshot_idxs
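# Illustrative usage sketch (hypothetical sizes): sampling 3 fewshot indices
# from a 10-example dataset; the current example's index is never included.
#     idxs = get_fewshot_sample_idxs(10, num_fewshot=3, example_idx=4,
#                                    rng=random.Random(0))
#     len(idxs) == 3 and 4 not in idxs  # -> True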


class MultiTokenEOSCriteria(transformers.StoppingCriteria):
    """Criteria to stop on the specified multi-token sequence.

    Slightly modified from: https://github.com/EleutherAI/lm-evaluation-harness/blob/78545d42f2ca95c6fe0ed220d456eeb94f4485e9/lm_eval/utils.py#L614-L649
    """

    def __init__(
        self,
        stop_sequence: str,
        tokenizer: transformers.PreTrainedTokenizerBase,
        batch_size: int,
    ) -> None:
        self.done_tracker = [False] * batch_size
        self.stop_sequence = stop_sequence
        self.stop_sequence_ids = tokenizer.encode(stop_sequence,
                                                  add_special_tokens=False)

        # sentencepiece tokenizers add a superfluous underline token before string-initial \n
        # that throws off our calculation of the stop sequence length,
        # so we remove any token ids that produce empty strings
        self.stop_sequence_ids = [
            id for id in self.stop_sequence_ids if tokenizer.decode(id) != ''
        ]

        # we look back for 1 more token than it takes to encode our stop sequence
        # because tokenizers suck, and a model might generate `['\n', '\n']` but our `sequence` is `['\n\n']`
        # and we don't want to mistakenly not stop a generation because our
        # (string) stop sequence was output in a different tokenization

        self.stop_sequence_id_len = len(self.stop_sequence_ids) + 1
        self.tokenizer = tokenizer

    def __call__(self,
                 input_ids: torch.LongTensor,
                 scores: Optional[torch.FloatTensor] = None,
                 **kwargs: Dict[str, Any]) -> bool:
        # For efficiency, we compare only the last `stop_sequence_id_len`
        # tokens of each sequence (one more than the stop sequence's length)
        lookback_ids_batch = input_ids[:, -self.stop_sequence_id_len:]
        lookback_tokens_batch = self.tokenizer.batch_decode(lookback_ids_batch)
        for i, done in enumerate(self.done_tracker):
            if i >= len(lookback_tokens_batch):
                # The last batch of a dataset may be smaller than `batch_size`.
                # Automatically set those indices in the done_tracker to True,
                # since those indices don't show up in the current batch.
                self.done_tracker[i] = True
                break
            elif not done:
                self.done_tracker[
                    i] = self.stop_sequence in lookback_tokens_batch[i]
        return False not in self.done_tracker


def stop_sequences_criteria(
    tokenizer: transformers.PreTrainedTokenizerBase,
    stop_sequences: List[str],
    batch_size: int,
) -> transformers.StoppingCriteriaList:
    """Builds a `StoppingCriteriaList` with one `MultiTokenEOSCriteria` per stop sequence."""
    return transformers.StoppingCriteriaList([
        *[
            MultiTokenEOSCriteria(sequence, tokenizer, batch_size)
            for sequence in stop_sequences
        ],
    ])
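
A minimal sketch of wiring these criteria into Hugging Face generation; `model` and `tokenizer` are assumed to be any loaded causal LM pair, and `stopping_criteria` is a standard argument to `generate`:

input_ids = tokenizer('Q: What is 2+2?\nA:', return_tensors='pt').input_ids
criteria = stop_sequences_criteria(tokenizer, ['\n\n', 'Q:'],
                                   batch_size=input_ids.shape[0])
output = model.generate(input_ids,
                        max_new_tokens=64,
                        stopping_criteria=criteria)
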
21 changes: 21 additions & 0 deletions llmfoundry/eval/metrics/__init__.py
@@ -0,0 +1,21 @@
# Copyright 2024 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

"""A collection of common torchmetrics."""

from llmfoundry.eval.metrics.nlp import (
    InContextLearningCodeEvalAccuracy,
    InContextLearningGenerationExactMatchAccuracy, InContextLearningLMAccuracy,
    InContextLearningLMExpectedCalibrationError,
    InContextLearningMCExpectedCalibrationError, InContextLearningMetric,
    InContextLearningMultipleChoiceAccuracy)

__all__ = [
    'InContextLearningLMAccuracy',
    'InContextLearningMultipleChoiceAccuracy',
    'InContextLearningGenerationExactMatchAccuracy',
    'InContextLearningMCExpectedCalibrationError',
    'InContextLearningLMExpectedCalibrationError',
    'InContextLearningMetric',
    'InContextLearningCodeEvalAccuracy',
]
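
For downstream code, the visible change is the import path; a short sketch (the Composer origin shown in the comment is an assumption based on this migration, not verified here):

from llmfoundry.eval.metrics import InContextLearningLMAccuracy

# Previously (assumed): from composer.metrics.nlp import InContextLearningLMAccuracy
metric = InContextLearningLMAccuracy()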