
Commit

fix suite 1
clefourrier committed Feb 7, 2024
1 parent e3cbe24 commit d9b262f
Showing 3 changed files with 55 additions and 20 deletions.
7 changes: 3 additions & 4 deletions src/lighteval/data.py
@@ -174,9 +174,8 @@ def _sorting_criteria(self, request: LoglikelihoodSingleTokenRequest) -> int:
            automatic adaptive batches much much easier to implement
            - any OOMs will happen right away rather than near the end
        """
-        toks = (
-            request.tokenized_context
-        )  # We take only the prompt, no need for the continuation (since it's a list of single tokens)
+        # We take only the prompt, no need for the continuation (since it's a list of single tokens)
+        toks = request.tokenized_context
        return -len(toks)
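
The docstring's rationale is easier to see in isolation. Below is a minimal, self-contained sketch of the longest-first ordering this sort key produces, with plain strings standing in for tokenized requests (an illustration, not lighteval's API):

requests = [
    "short prompt",
    "a much, much longer prompt than all the others",
    "a mid-sized prompt",
]

# Longest first: any OOM surfaces in the very first batches, and the
# worst-case batch shape is known up front, which makes adaptive
# batch sizing straightforward.
for req in sorted(requests, key=lambda r: -len(r)):
    print(len(req), req)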


@@ -191,7 +190,7 @@ def _sorting_criteria(self, request: GreedyUntilRequest | GreedyUntilWithLogitsRequest) -> int:
        Returns:
            Any: The collated data.
        """
-        toks = (request.context,)
+        toks = request.context
        gen_length = request.generation_size
        return -(len(toks) + gen_length)
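
The removed form was a classic Python pitfall: (request.context,) is a 1-tuple because of the trailing comma, so len(toks) was always 1 and every request sorted as if its context were the same length. A two-assert illustration with a made-up context string:

context = "The quick brown fox jumps over the lazy dog"

toks = (context,)   # pre-fix: a 1-tuple, so len(toks) == 1 for every request
assert len(toks) == 1

toks = context      # post-fix: the real context length drives the sort
assert len(toks) == 43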

16 changes: 8 additions & 8 deletions src/lighteval/models/base_model.py
@@ -39,10 +39,6 @@


class BaseModel(LightevalModel):
-    # Default max sequence length setting for when no `max_length` is provided
-    # or no max length config setting is found in the model or tokenizer.
-    _DEFAULT_MAX_LENGTH: int = 2048
-
    def __init__(
        self,
        config: BaseModelConfig,
@@ -239,7 +235,9 @@ def _init_max_length(self, max_length) -> int:

        if hasattr(self.tokenizer, "model_max_length"):
            return self.tokenizer.model_max_length
-        return self._DEFAULT_MAX_LENGTH
+        # Default max sequence length setting for when no `max_length` is provided
+        # or no max length config setting is found in the model or tokenizer.
+        return 2048

@property
def batch_size(self) -> int:
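
Together with the class-level constant removed above, this inlines the default at its only remaining use site. A hedged sketch of the resolution order the diff implies, written as a free function; the attribute names checked on the model config are hypothetical stand-ins, since that part of the method is not shown in this hunk:

def init_max_length(max_length, config, tokenizer):
    # Explicit argument wins, then any limit advertised by the model
    # config or tokenizer, then the hard default of 2048.
    if max_length is not None:
        return max_length
    for attr in ("max_position_embeddings", "n_positions"):  # hypothetical names
        if hasattr(config, attr):
            return getattr(config, attr)
    if hasattr(tokenizer, "model_max_length"):
        return tokenizer.model_max_length
    return 2048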
@@ -696,7 +694,8 @@ def prepare_batch(
            raise ValueError("Negative padding")

        padded.append(padding_length - sequence_len)
-        tokens = F.pad(tokens, (padding_length - sequence_len), value=0)
+        # Right padding - it likely would be better to do left padding
+        tokens = F.pad(tokens, (0, padding_length - sequence_len), value=0)

# We create the attention mask to ignore padding
mask = tokens == 0
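
This is the substantive fix in the file: torch.nn.functional.pad takes pairs of (left, right) pad widths per trailing dimension, and in Python (n) is just the integer n rather than a 1-tuple, so the old call never supplied a valid pair. A minimal right-padding illustration with made-up token ids:

import torch
import torch.nn.functional as F

tokens = torch.tensor([11, 12, 13])  # hypothetical token ids, sequence_len = 3
padding_length = 5

# (left, right) = (0, 2): append two pad tokens on the right of the last dim.
padded = F.pad(tokens, (0, padding_length - len(tokens)), value=0)
print(padded)  # tensor([11, 12, 13,  0,  0])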
@@ -782,10 +781,11 @@ def _loglikelihood_single_token(
        dataloader = self.accelerator.prepare(dataloader)

        for batch in tqdm(dataloader, disable=self.disable_tqdm, position=1):
-            prepared_batch = self.prepare_batch(batch, padding_length=max_context, max_context=max_context)
+            prepared_batch = self.prepare_batch(
+                batch, padding_length=max_context, max_context=max_context, single_token=True
+            )

            out = self._model_call(prepared_batch.input_ids)  # [batch, padding_length, vocab]

            out = F.log_softmax(out, dim=-1)  # we do a softmax over the options, not the vocab

            batch_probs = []
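For context on the surrounding loop, a hedged, self-contained sketch of the single-token scoring step: after log_softmax, each option's score is a gather from the position following the prompt. Shapes follow the diff's [batch, padding_length, vocab] comment; the logits, positions, and option ids below are invented for illustration:

import torch
import torch.nn.functional as F

logprobs = F.log_softmax(torch.randn(2, 8, 100), dim=-1)  # [batch, padding_length, vocab]
last_prompt_pos = torch.tensor([4, 6])       # index of each prompt's final token
option_ids = torch.tensor([[3, 7], [3, 7]])  # single-token answer candidates

step = logprobs[torch.arange(2), last_prompt_pos]  # [batch, vocab] at the answer position
option_scores = step.gather(1, option_ids)         # [batch, n_options]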
52 changes: 44 additions & 8 deletions tests/test_unit_reorder.py
@@ -1,13 +1,49 @@
from lighteval.data import GenerativeTaskDataset
+from lighteval.tasks.requests import GreedyUntilRequest


# test data that will need to be sorted by length of the string
data = [
-    ("1 The quick brown fox jumps over the lazy dog", ([":", "stop"], 10)),
-    ("2 The quick brown fox jumps over the lazy dog njsa", ([":", "stop"], 10)),
-    ("Some text", ([":", "stop"], 10)),
-    ("some more text", ([":", "stop"], 10)),
-    ("not sure what to write here", ([":", "stop"], 10)),
+    GreedyUntilRequest(
+        task_name="test",
+        example_index=0,
+        request_index=0,
+        context="1 The quick brown fox jumps over the lazy dog",
+        stop_sequence=[":", "stop"],
+        generation_size=10,
+    ),
+    GreedyUntilRequest(
+        task_name="test",
+        example_index=2,
+        request_index=0,
+        context="2 The quick brown fox jumps over the lazy dog njsa",
+        stop_sequence=[":", "stop"],
+        generation_size=10,
+    ),
+    GreedyUntilRequest(
+        task_name="test",
+        example_index=5,
+        request_index=0,
+        context="Some text",
+        stop_sequence=[":", "stop"],
+        generation_size=10,
+    ),
+    GreedyUntilRequest(
+        task_name="test",
+        example_index=21,
+        request_index=0,
+        context="some more text",
+        stop_sequence=[":", "stop"],
+        generation_size=10,
+    ),
+    GreedyUntilRequest(
+        task_name="test",
+        example_index=1,
+        request_index=0,
+        context="not sure what to write here",
+        stop_sequence=[":", "stop"],
+        generation_size=10,
+    ),
]

DATASET_SPLITS = 1
@@ -21,9 +57,9 @@ def test_reorder_dataset(self):
original_data = dataset.get_original_order(sorted_data)

for i in range(len(sorted_data) - 1):
-            assert len(sorted_data[i][0]) >= len(
-                sorted_data[i + 1][0]
-            ), f"dataset[{i}][0] = {sorted_data[i][0]} is shorter than dataset[{i+1}][0] = {sorted_data[i+1][0]}"
+            assert (
+                len(sorted_data[i].context) >= len(sorted_data[i + 1].context)
+            ), f"dataset[{i}][0] = {sorted_data[i].context} is shorter than dataset[{i+1}][0] = {sorted_data[i+1].context}"

assert len(sorted_data) == len(
original_data
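The round trip this test exercises, sort by length and then recover submission order, is a small self-contained pattern. Here it is in plain Python as an illustration; GenerativeTaskDataset wraps the same idea, but this is not its implementation:

items = ["bb", "dddd", "a", "ccc"]

# Sort indices longest-first and keep the permutation so it can be undone.
order = sorted(range(len(items)), key=lambda i: -len(items[i]))
sorted_items = [items[i] for i in order]  # ['dddd', 'ccc', 'bb', 'a']

restored = [None] * len(items)
for rank, original_index in enumerate(order):
    restored[original_index] = sorted_items[rank]  # invert the permutation
assert restored == items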
