diff --git a/src/lighteval/models/base_model.py b/src/lighteval/models/base_model.py
index 6b67e4c0..616bd0b3 100644
--- a/src/lighteval/models/base_model.py
+++ b/src/lighteval/models/base_model.py
@@ -20,6 +20,7 @@
     LoglikelihoodRequest,
     LoglikelihoodRollingRequest,
     LoglikelihoodSingleTokenRequest,
+    Request,
 )
 from lighteval.utils import (
     is_accelerate_available,
@@ -357,10 +358,10 @@ def greedy_until(
         ):
             # Longest context in the current split is the first item (since we sort reversed)
             longest_context_continuation_size_in_split = len(dataset[0].tokenized_context) + dataset[0].generation_size
-            max_continuation_size_allowed = min(longest_context_continuation_size_in_split, self.max_length)
+            max_context_continuation_size_allowed = min(longest_context_continuation_size_in_split, self.max_length)
             batch_size = self._get_batch_size(
                 override_bs=override_bs,
-                max_input_length=max_continuation_size_allowed,
+                max_input_length=max_context_continuation_size_allowed,
                 starting_batch_size=starting_batch_size,
             )
 
@@ -529,11 +530,13 @@ def _loglikelihood_tokens(
         for split_start, split_end in tqdm(dataset.splits_start_end_iterator()):
             context_enc = dataset[0].tokenized_context
             continuation_enc = dataset[0].tokenized_continuation
-            max_continuation_size_allowed = len((context_enc + continuation_enc)[-(self.max_length + 1) :][:-1])
+            max_context_continuation_size_allowed = len(
+                (context_enc + continuation_enc)[-(self.max_length + 1) :][:-1]
+            )
             batch_size = self._get_batch_size(
                 override_bs=override_bs,
-                max_input_length=max_continuation_size_allowed,
+                max_input_length=max_context_continuation_size_allowed,
                 starting_batch_size=starting_batch_size,
             )
             starting_batch_size = batch_size * 2
 
@@ -544,7 +547,9 @@ def _loglikelihood_tokens(
             for batch in tqdm(dataloader, disable=self.disable_tqdm):
                 prepared_batch = self.prepare_batch(
-                    batch, padding_length=max_continuation_size_allowed, max_context=max_continuation_size_allowed
+                    batch,
+                    padding_length=max_context_continuation_size_allowed,
+                    max_context=max_context_continuation_size_allowed,
                 )
 
                 model_output = self._model_call(prepared_batch.input_ids)