Skip to content

Commit

Permalink
revert name + add import fix
Browse files Browse the repository at this point in the history
  • Loading branch information
clefourrier committed Feb 7, 2024
1 parent dc454fd commit 827488e
Showing 1 changed file with 10 additions and 5 deletions.
15 changes: 10 additions & 5 deletions src/lighteval/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
LoglikelihoodRequest,
LoglikelihoodRollingRequest,
LoglikelihoodSingleTokenRequest,
Request,
)
from lighteval.utils import (
is_accelerate_available,
Expand Down Expand Up @@ -357,10 +358,10 @@ def greedy_until(
):
# Longest context in the current split is the first item (since we sort reversed)
longest_context_continuation_size_in_split = len(dataset[0].tokenized_context) + dataset[0].generation_size
max_continuation_size_allowed = min(longest_context_continuation_size_in_split, self.max_length)
max_context_continuation_size_allowed = min(longest_context_continuation_size_in_split, self.max_length)
batch_size = self._get_batch_size(
override_bs=override_bs,
max_input_length=max_continuation_size_allowed,
max_input_length=max_context_continuation_size_allowed,
starting_batch_size=starting_batch_size,
)

Expand Down Expand Up @@ -529,11 +530,13 @@ def _loglikelihood_tokens(
for split_start, split_end in tqdm(dataset.splits_start_end_iterator()):
context_enc = dataset[0].tokenized_context
continuation_enc = dataset[0].tokenized_continuation
max_continuation_size_allowed = len((context_enc + continuation_enc)[-(self.max_length + 1) :][:-1])
max_context_continuation_size_allowed = len(
(context_enc + continuation_enc)[-(self.max_length + 1) :][:-1]
)

batch_size = self._get_batch_size(
override_bs=override_bs,
max_input_length=max_continuation_size_allowed,
max_input_length=max_context_continuation_size_allowed,
starting_batch_size=starting_batch_size,
)
starting_batch_size = batch_size * 2
Expand All @@ -544,7 +547,9 @@ def _loglikelihood_tokens(

for batch in tqdm(dataloader, disable=self.disable_tqdm):
prepared_batch = self.prepare_batch(
batch, padding_length=max_continuation_size_allowed, max_context=max_continuation_size_allowed
batch,
padding_length=max_context_continuation_size_allowed,
max_context=max_context_continuation_size_allowed,
)

model_output = self._model_call(prepared_batch.input_ids)
Expand Down

0 comments on commit 827488e

Please sign in to comment.