Skip to content

Commit

Permalink
Considering the case empty request list is given to base model (#250)
Browse files Browse the repository at this point in the history

---------

Co-authored-by: Nathan Habib <[email protected]>
Co-authored-by: Clémentine Fourrier <[email protected]>
  • Loading branch information
3 people authored Nov 22, 2024
1 parent 24d5feb commit 2c9bf97
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 6 deletions.
17 changes: 11 additions & 6 deletions src/lighteval/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

from typing import Iterator
import math
from typing import Iterator, Tuple

import torch
from torch.utils.data import Dataset
Expand Down Expand Up @@ -80,7 +81,7 @@ def init_split_limits(self, num_dataset_splits):
)
num_dataset_splits = 1

split_size = self.total_size // num_dataset_splits + 1
split_size = math.ceil(self.total_size / num_dataset_splits)
splits_indices = [
(ix * split_size, min((ix + 1) * split_size, self.total_size)) for ix in range(num_dataset_splits)
]
Expand Down Expand Up @@ -110,7 +111,7 @@ def get_original_order(self, new_arr: list) -> list:

return original_order

def get_split_start_end(self, split_id: int) -> tuple[int, int]:
def get_split_start_end(self, split_id: int) -> Tuple[int, int]:
"""
Get the start and end indices of a dataset split.
Expand All @@ -123,7 +124,7 @@ def get_split_start_end(self, split_id: int) -> tuple[int, int]:
self.split_start, self.split_end = self.splits[split_id]
return self.split_start, self.split_end

def splits_start_end_iterator(self) -> tuple[int, int]:
def splits_start_end_iterator(self) -> Iterator[Tuple[int, int]]:
"""
Iterator that yields the start and end indices of each dataset split.
Also updates the starting batch size for each split (trying to double
Expand All @@ -132,7 +133,10 @@ def splits_start_end_iterator(self) -> tuple[int, int]:
Yields:
tuple: A tuple containing the start and end indices of a split.
"""
for split_id in range(self.num_dataset_splits):
split_range = self.num_dataset_splits
if self.total_size == 0:
split_range = 0
for split_id in range(split_range):
yield self.get_split_start_end(split_id)

def __getitem__(self, index) -> Request:
Expand Down Expand Up @@ -247,7 +251,8 @@ def init_split_limits(self, num_dataset_splits):
"You cannot select the number of dataset splits for a generative evaluation at the moment. Automatically inferring."
)

all_sorting_criterion = [self._sorting_criteria(self.sorted_data[0])[:2]]
if len(self.sorted_data) > 0:
all_sorting_criterion = [self._sorting_criteria(self.sorted_data[0])[:2]]
splits_indices = [[0, None]]
for ix, req in enumerate(self.sorted_data):
current_sorting_criteria = self._sorting_criteria(req)
Expand Down
37 changes: 37 additions & 0 deletions tests/models/test_base_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# MIT License

# Copyright (c) 2024 The HuggingFace Team

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

from lighteval.models.base_model import BaseModel
from lighteval.models.model_config import BaseModelConfig
from lighteval.models.model_loader import load_model
from lighteval.utils.utils import EnvConfig


def test_empty_requests():
    """Check that every request-processing entry point returns ``[]`` for an empty request list.

    A tiny random Llama checkpoint is loaded through ``load_model`` and each
    of the model's evaluation methods is called with an empty list.
    """
    config = BaseModelConfig("hf-internal-testing/tiny-random-LlamaForCausalLM")
    env = EnvConfig(cache_dir=".")
    model: BaseModel = load_model(config=config, env_config=env)

    # Same methods, same order as exercised originally; each must yield an empty list.
    entry_points = (
        "loglikelihood",
        "loglikelihood_single_token",
        "loglikelihood_rolling",
        "greedy_until",
        "greedy_until_multi_turn",
    )
    for method_name in entry_points:
        assert getattr(model, method_name)([]) == []

0 comments on commit 2c9bf97

Please sign in to comment.