pass 1 after code review
clefourrier committed Feb 7, 2024
1 parent cc1e7dd commit 9b0fe54
Showing 7 changed files with 144 additions and 113 deletions.
5 changes: 4 additions & 1 deletion src/lighteval/data.py
@@ -35,6 +35,9 @@ def __init__(
requests (List): A list of requests.
dataset_splits (int): The number of dataset splits.
"""
# We make sure the requests contain the tokenized versions of their values
if any(r.tokenized_context is None for r in requests):
raise ValueError("You passed a request for which tokenization had not happened yet.")

# sort the requests using the collate function and save the original order
enumerated_requests = list(enumerate(requests))
@@ -190,7 +193,7 @@ def _sorting_criteria(self, request: GreedyUntilRequest | GreedyUntilWithLogitsR
Returns:
Any: The collated data.
"""
toks = request.context
toks = request.tokenized_context
gen_length = request.generation_size
return -(len(toks) + gen_length)

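For context, here is a minimal, hedged sketch of what this sort key achieves (the ToyRequest dataclass below is purely illustrative and not part of lighteval): sorting on -(len(tokenized_context) + generation_size) puts the most expensive request of each split first, which is what the batch-size probing in base_model.py relies on.

from dataclasses import dataclass

@dataclass
class ToyRequest:
    tokenized_context: list[int]
    generation_size: int

def sorting_criteria(request: ToyRequest) -> int:
    # Mirror of the sort key above: longest context + generation budget first.
    return -(len(request.tokenized_context) + request.generation_size)

requests = [
    ToyRequest(tokenized_context=[0] * 10, generation_size=32),
    ToyRequest(tokenized_context=[0] * 200, generation_size=256),
    ToyRequest(tokenized_context=[0] * 50, generation_size=8),
]

# Ascending sort on the negative key puts the most expensive request first:
# prints 200 256, then 50 8, then 10 32.
for r in sorted(requests, key=sorting_criteria):
    print(len(r.tokenized_context), r.generation_size)
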
1 change: 1 addition & 0 deletions src/lighteval/metrics/__init__.py
@@ -86,6 +86,7 @@ def apply_multichoice_metric(results: list[ModelReturn], formatted_doc: Doc, met
"You can't use a multi choice metric with only one choice. Use `acc_golds_likelihood` instead."
)

# TODO: make a better system with return_bool_score instead of taking the first element
choices_logprob = [results[i].result[0] for i in range(len(formatted_doc.choices))] # sum(
gold_ixs = as_list(formatted_doc.gold_index)

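The remainder of apply_multichoice_metric is not shown in this hunk. As a hedged illustration only (the multichoice_accuracy helper below is not lighteval code), per-choice log-probabilities such as choices_logprob are typically turned into an accuracy-style score by comparing the argmax against the gold indices:

import numpy as np

def multichoice_accuracy(choices_logprob: list[float], gold_ixs: list[int]) -> int:
    """Illustrative scoring: 1 if the highest-logprob choice is one of the golds, else 0."""
    best_choice = int(np.argmax(choices_logprob))
    return int(best_choice in gold_ixs)

# Example: three choices, the second one (index 1) is gold and has the highest log-likelihood.
print(multichoice_accuracy([-4.2, -1.3, -3.7], gold_ixs=[1]))  # prints 1
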
62 changes: 52 additions & 10 deletions src/lighteval/models/abstract_model.py
@@ -1,5 +1,8 @@
from abc import ABC, abstractmethod
from typing import Optional
from typing import Optional, Union

import torch
from transformers import BatchEncoding

from lighteval.models.model_config import EnvConfig
from lighteval.models.model_output import GenerateReturn, LoglikelihoodReturn, LoglikelihoodSingleTokenReturn
@@ -12,7 +15,12 @@
)


TokenSequence = Union[list[int], torch.LongTensor, torch.Tensor, BatchEncoding]


class LightevalModel(ABC):
DATASET_SPLITS = 4

"""Abstract model class defining the API that every model to plug into lighteval must follow."""

@abstractmethod
@@ -21,24 +29,36 @@ def __init__(
config,
env_config: EnvConfig,
):
self.tokenizer = None
return NotImplemented

def cleanup(self):
"""Clean up operations if needed, such as closing an endpoint."""
return

@property
@abstractmethod
def tokenizer(self):
raise NotImplementedError

@property
@abstractmethod
def add_special_tokens(self):
raise NotImplementedError

@property
@abstractmethod
def max_length(self) -> int:
"""Return the maximum sequence length of the model."""
raise NotImplementedError

@property
def disable_tqdm(self) -> bool:
raise NotImplementedError

def greedy_until_with_logits(
self,
requests: list[GreedyUntilWithLogitsRequest],
disable_tqdm: bool = False,
override_bs: Optional[int] = None,
dataset_splits: int = 4,
) -> list[GenerateReturn]:
"""
Generates sequences greedily until a stopping condition is met,
@@ -49,17 +69,14 @@ def greedy_until_with_logits(
where each request is a tuple containing a prompt string and a dictionary of additional parameters.
disable_tqdm (bool, optional): Whether to disable the tqdm progress bar. Defaults to False.
override_bs (Optional[int], optional): Overrides the batch size for generation. Defaults to None.
dataset_splits (int, optional): Number of splits to divide the dataset into for parallel generation. Defaults to 4.
Returns:
list[GenerateReturn]: A list of GenerateReturn objects,
where each object contains the generated sequence and the corresponding logits.
"""
return self.greedy_until(
requests=requests,
disable_tqdm=disable_tqdm,
override_bs=override_bs,
dataset_splits=dataset_splits,
returns_logits=True,
)

@@ -68,9 +85,7 @@ def greedy_until(
self,
requests: list[GreedyUntilRequest],
returns_logits: bool = False,
disable_tqdm: bool = False,
override_bs: Optional[int] = None,
dataset_splits: int = 4,
) -> list[GenerateReturn]:
"""
Generates responses using a greedy decoding strategy until certain ending conditions are met.
@@ -80,7 +95,6 @@
returns_logits (bool, optional): Whether to return the logits of the generated responses. Defaults to False.
disable_tqdm (bool, optional): Whether to disable the progress bar. Defaults to False.
override_bs (int, optional): Override the batch size for generation. Defaults to None.
dataset_splits (int, optional): Number of splits to divide the dataset into. Defaults to 4.
Returns:
list[GenerateReturn]: list of generated responses.
@@ -111,3 +125,31 @@ def loglikelihood_single_token(
tokenized sequences.
"""
return NotImplemented

# Tokenization utils
def tok_encode(self, str_to_encode: str | list[str], add_special_tokens: Optional[bool] = None) -> TokenSequence:
if add_special_tokens is None:
add_special_tokens = self.add_special_tokens
if isinstance(str_to_encode, str):
return self.tokenizer.encode(str_to_encode, add_special_tokens=add_special_tokens)
return self.tokenizer(
str_to_encode,
padding=True,
add_special_tokens=add_special_tokens,
return_tensors="pt",
)

def tok_encode_pair(self, context, continuation):
"""Encodes a context, continuation pair by taking care of the spaces in between."""
n_spaces = len(context) - len(context.rstrip())
if n_spaces > 0:
continuation = context[-n_spaces:] + continuation
context = context[:-n_spaces]
whole_enc = self.tok_encode(context + continuation)
context_enc = self.tok_encode(context)
context_enc_len = len(context_enc)
continuation_enc = whole_enc[context_enc_len:]
return context_enc, continuation_enc

def tok_decode(self, tokens: torch.LongTensor) -> list[str]:
return self.tokenizer.batch_decode(tokens, skip_special_tokens=True)
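
To make the space handling in tok_encode_pair concrete, here is a self-contained sketch with a toy whitespace tokenizer (ToyTokenizer is illustrative only; real models use a transformers tokenizer): trailing whitespace is moved from the context onto the continuation before encoding, so the continuation tokens line up with the encoding of the concatenated string.

class ToyTokenizer:
    """Whitespace 'tokenizer' used only to illustrate the alignment logic."""

    def encode(self, text: str) -> list[str]:
        return text.split()

def tok_encode_pair(tokenizer: ToyTokenizer, context: str, continuation: str):
    # Same logic as LightevalModel.tok_encode_pair above.
    n_spaces = len(context) - len(context.rstrip())
    if n_spaces > 0:
        continuation = context[-n_spaces:] + continuation
        context = context[:-n_spaces]
    whole_enc = tokenizer.encode(context + continuation)
    context_enc = tokenizer.encode(context)
    continuation_enc = whole_enc[len(context_enc):]
    return context_enc, continuation_enc

ctx_enc, cont_enc = tok_encode_pair(ToyTokenizer(), "Question: 2 + 2 = ", "4")
print(ctx_enc)   # ['Question:', '2', '+', '2', '=']
print(cont_enc)  # ['4']

With a subword tokenizer the same slicing guarantees that the continuation log-likelihood is computed over tokens that actually appear in the concatenated encoding.
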
88 changes: 31 additions & 57 deletions src/lighteval/models/base_model.py
@@ -6,7 +6,7 @@
import transformers
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, BatchEncoding
from transformers import AutoModelForCausalLM, AutoTokenizer

from lighteval.data import GenerativeTaskDataset, LoglikelihoodDataset, LoglikelihoodSingleTokenDataset
from lighteval.logging.hierarchical_logger import hlog, hlog_err, hlog_warn
@@ -32,9 +32,6 @@

os.environ["TOKENIZERS_PARALLELISM"] = "false"

TokenSequence = Union[list[int], torch.LongTensor, torch.Tensor, BatchEncoding]

DATASET_SPLITS = 4
STARTING_BATCH_SIZE = 512


@@ -50,8 +47,8 @@ def __init__(
self._batch_size = config.batch_size
self._max_length = self._init_max_length(config.max_length)

self.add_special_tokens = config.add_special_tokens if config.add_special_tokens is not None else False
self.tokenizer = self._create_auto_tokenizer(config, env_config)
self._add_special_tokens = config.add_special_tokens if config.add_special_tokens is not None else False
self._tokenizer = self._create_auto_tokenizer(config, env_config)

# If model_parallel is not set we compare the number of process with the number of GPUs
self.model = self._create_auto_model(config, env_config)
@@ -73,6 +70,18 @@ def __init__(

self.precision = _get_precision(config, model_auto_config=self._config)

@property
def tokenizer(self):
return self._tokenizer

@property
def add_special_tokens(self):
return self._add_special_tokens

@property
def max_length(self) -> int:
return self._max_length

def init_model_parallel(self, model_parallel: bool = None) -> Tuple[bool, Optional[dict], Optional[str]]:
"""Compute all the parameters related to model_parallel"""
if not is_accelerate_available():
@@ -203,10 +212,6 @@ def _create_auto_tokenizer_with_name(

return tokenizer

@property
def max_length(self) -> int:
return self._max_length

def _init_max_length(self, max_length) -> int:
"""Return the maximum sequence length of the model.
NOTE: Different model configurations have different max sequence length
@@ -256,33 +261,6 @@ def disable_tqdm(self) -> bool:
disable_tqdm = bool(not self.accelerator.is_main_process)
return disable_tqdm

# Tokenization helpers
def tok_encode_pair(self, context, continuation):
n_spaces = len(context) - len(context.rstrip())
if n_spaces > 0:
continuation = context[-n_spaces:] + continuation
context = context[:-n_spaces]
whole_enc = self.tok_encode(context + continuation)
context_enc = self.tok_encode(context)
context_enc_len = len(context_enc)
continuation_enc = whole_enc[context_enc_len:]
return context_enc, continuation_enc

def tok_encode(self, str_to_encode: str | list[str], add_special_tokens: Optional[bool] = None) -> TokenSequence:
if add_special_tokens is None:
add_special_tokens = self.add_special_tokens
if isinstance(str_to_encode, str):
return self.tokenizer.encode(str_to_encode, add_special_tokens=add_special_tokens)
return self.tokenizer(
str_to_encode,
padding=True,
add_special_tokens=add_special_tokens,
return_tensors="pt",
)

def tok_decode(self, tokens: torch.LongTensor) -> list[str]:
return self.tokenizer.batch_decode(tokens, skip_special_tokens=True)

def _check_continuations_start_space(self, continuation: str) -> str:
"""Some models tokenizer want a space at the beginning and other not. We update this if needed here.
multichoice_continuations_start_space can be:
@@ -323,7 +301,6 @@ def greedy_until_with_logits(
self,
requests: list[GreedyUntilWithLogitsRequest],
override_bs: Optional[int] = None,
dataset_splits: int = 4,
) -> list[GenerateReturn]:
"""
Generates sequences greedily until a stopping condition is met,
@@ -333,7 +310,6 @@
requests (list[tuple[str, dict]]): A list of input requests,
where each request is a tuple containing a prompt string and a dictionary of additional parameters.
override_bs (Optional[int], optional): Overrides the batch size for generation. Defaults to None.
dataset_splits (int, optional): Number of splits to divide the dataset into for parallel generation. Defaults to 4.
Returns:
list[GenerateReturn]: A list of GenerateReturn objects,
@@ -345,15 +321,13 @@
returns_logits=True,
disable_tqdm=self.disable_tqdm,
override_bs=override_bs,
dataset_splits=dataset_splits,
)

def greedy_until(
self,
requests: list[GreedyUntilRequest],
returns_logits: bool = False,
override_bs: Optional[int] = None,
dataset_splits: int = 4,
) -> list[GenerateReturn]:
"""
Generates responses using a greedy decoding strategy until certain ending conditions are met.
@@ -362,34 +336,33 @@
requests (list[Request]): list of requests containing the context and ending conditions.
returns_logits (bool, optional): Whether to return the logits of the generated responses. Defaults to False.
override_bs (int, optional): Override the batch size for generation. Defaults to None.
dataset_splits (int, optional): Number of splits to divide the dataset into. Defaults to 4.
Returns:
list[GenerateReturn]: list of generated responses.
"""
for request in requests:
request.stop_sequence = request.stop_sequence + [self.tokenizer.eos_token]
dataset = GenerativeTaskDataset(requests=requests, dataset_splits=dataset_splits)
request.tokenized_context = self.tok_encode(request.context)

dataset = GenerativeTaskDataset(requests=requests, dataset_splits=self.DATASET_SPLITS)
starting_batch_size = STARTING_BATCH_SIZE
results = []

for split_start, split_end in tqdm(
dataset.splits_start_end_iterator(),
total=DATASET_SPLITS,
total=self.DATASET_SPLITS,
desc="Splits",
position=0,
disable=self.disable_tqdm,
):
# Longest context in the current split is the first item (since we sort reversed)
longest_context_continuation_size_in_split = len(dataset[0].context) + dataset[0].generation_size
longest_context_continuation_size_in_split = len(dataset[0].tokenized_context) + dataset[0].generation_size
max_continuation_size_allowed = min(longest_context_continuation_size_in_split, self.max_length)
batch_size = self._get_batch_size(
override_bs=override_bs,
max_input_length=max_continuation_size_allowed,
starting_batch_size=starting_batch_size,
)
# For the next split the sequences are shorter, so we'll try a bigger starting batch size
starting_batch_size = batch_size * 2

dataloader = DataLoader(dataset, batch_size=batch_size, collate_fn=lambda batch: batch)
if self.accelerator:
@@ -500,7 +473,9 @@ def _generate(
return all_responses

def loglikelihood(
self, requests: list[LoglikelihoodRequest], override_bs: Optional[int] = None
self,
requests: list[LoglikelihoodRequest],
override_bs: Optional[int] = None,
) -> list[LoglikelihoodReturn]:
"""Tokenize the context and continuation and compute the log likelihood of those
tokenized sequences.
@@ -516,16 +491,17 @@
request.tokenized_context = [self.tokenizer.eos_token_id]
request.tokenized_continuation = self.tok_encode(request.choice)
else:
# DO NOT CHANGE THE FOLLOWING LINE!
# It is mandatory for compatibility with the harness!!!
# The following line is mandatory for compatibility with the harness
request.tokenized_context, request.tokenized_continuation = self.tok_encode_pair(
request.context, request.choice
)

return self._loglikelihood_tokens(requests, override_bs=override_bs, dataset_splits=DATASET_SPLITS)
return self._loglikelihood_tokens(requests, override_bs=override_bs)

def loglikelihood_rolling(
self, requests: list[LoglikelihoodRollingRequest], override_bs=None
self,
requests: list[LoglikelihoodRollingRequest],
override_bs=None,
) -> list[LoglikelihoodReturn]:
"""This function is used to compute the log likelihood of the context for perplexity metrics."""

@@ -537,18 +513,16 @@ def loglikelihood_rolling(
requests,
override_bs=override_bs,
return_bool_score=False,
dataset_splits=DATASET_SPLITS,
)
return results

def _loglikelihood_tokens(
self,
requests: list[LoglikelihoodRequest],
override_bs: int = -1,
dataset_splits: int = 4,
return_bool_score: bool = True,
) -> list[LoglikelihoodReturn]:
dataset = LoglikelihoodDataset(requests=requests, dataset_splits=dataset_splits)
dataset = LoglikelihoodDataset(requests=requests, dataset_splits=self.DATASET_SPLITS)
starting_batch_size = STARTING_BATCH_SIZE
res = []

@@ -758,9 +732,9 @@ def loglikelihood_single_token(
return self._loglikelihood_single_token(requests, override_bs=override_bs)

def _loglikelihood_single_token(
self, requests: list[LoglikelihoodSingleTokenRequest], override_bs: int = -1, dataset_splits: int = 4
self, requests: list[LoglikelihoodSingleTokenRequest], override_bs: int = -1
) -> list[LoglikelihoodSingleTokenReturn]:
dataset = LoglikelihoodSingleTokenDataset(requests=requests, dataset_splits=dataset_splits)
dataset = LoglikelihoodSingleTokenDataset(requests=requests, dataset_splits=self.DATASET_SPLITS)
starting_batch_size = STARTING_BATCH_SIZE
res = []

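
The same batching pattern recurs in greedy_until, _loglikelihood_tokens and _loglikelihood_single_token above: each split is sorted longest-first, a batch size is probed against its most expensive request, and the probe for the next split starts at twice the size that just worked, since later splits only contain shorter sequences. A hedged sketch of that pattern follows (probe_batch_size and the toy memory model are illustrative, not lighteval's _get_batch_size):

def probe_batch_size(starting_batch_size: int, max_input_length: int, fits) -> int:
    """Illustrative probe: halve the batch size until `fits` reports it runs without OOM."""
    batch_size = starting_batch_size
    while batch_size > 1 and not fits(batch_size, max_input_length):
        batch_size //= 2
    return batch_size

# Toy memory model: pretend the GPU holds ~65k tokens per forward pass.
fits = lambda bs, length: bs * length <= 65_536

starting_batch_size = 512
for longest_in_split in (2048, 1024, 512):  # splits are sorted longest-first
    batch_size = probe_batch_size(starting_batch_size, longest_in_split, fits)
    print(longest_in_split, batch_size)  # first split: 2048 32; then 1024 64; then 512 128
    # The next split is shorter, so optimistically start from a larger batch size.
    starting_batch_size = batch_size * 2
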
(Diffs for the remaining 3 changed files were not loaded on this page.)
