Merge branch 'main' into nathan-add-cli-tool

clefourrier authored Jul 11, 2024
2 parents ebb5611 + 66e6aae commit 035aab7
Showing 6 changed files with 84 additions and 52 deletions.
92 changes: 62 additions & 30 deletions src/lighteval/data.py
@@ -40,12 +40,12 @@ class DynamicBatchDataset(Dataset):
     def __init__(
         self,
         requests: list,
-        dataset_splits: int,
+        num_dataset_splits: int,
     ):
         """
         This dataset class uses dynamic batching to speed up the generation.
         Each request is sorted by the length of the prompt + the length of the
-        continuation. Then, the dataset is split into dataset_splits splits.
+        continuation. Then, the dataset is split into num_dataset_splits splits.
         The first split will contain the longest requests, the second split will
         contain the second longest requests, etc. This allows us to use dynamic
         batching by starting with a small batch size and doubling it for each
@@ -54,7 +54,7 @@ def __init__(
         Args:
             requests (List): A list of requests.
-            dataset_splits (int): The number of dataset splits.
+            num_dataset_splits (int): The number of dataset splits.
         """
         # We make sure the requests contain the tokenized versions of their values
         if any(r.tokenized_context is None for r in requests):
@@ -69,16 +69,23 @@ def __init__(

         self.total_size = len(self.sorted_data)

-        if dataset_splits >= self.total_size:
+        self.num_dataset_splits, self.splits = self.init_split_limits(num_dataset_splits)
+
+        self.split_start, self.split_end = self.splits[0]
+
+    def init_split_limits(self, num_dataset_splits):
+        if num_dataset_splits >= self.total_size:
             hlog_warn(
-                f"dataset_splits ({dataset_splits}) >= total_size ({self.total_size}), setting dataset_splits to 1"
+                f"num_dataset_splits ({num_dataset_splits}) >= total_size ({self.total_size}), setting num_dataset_splits to 1"
             )
-            dataset_splits = 1
+            num_dataset_splits = 1

-        self.dataset_splits = dataset_splits
-        self.split_size = self.total_size // self.dataset_splits + 1
-        self.split_start = 0
-        self.split_end = min(self.split_start + self.split_size, self.total_size)
+        split_size = self.total_size // num_dataset_splits + 1
+        splits_indices = [
+            (ix * split_size, min((ix + 1) * split_size, self.total_size)) for ix in range(num_dataset_splits)
+        ]
+
+        return num_dataset_splits, splits_indices

     def get_original_order(self, new_arr: list) -> list:
         """
@@ -113,8 +120,7 @@ def get_split_start_end(self, split_id: int) -> tuple[int, int]:
         Returns:
             tuple: A tuple containing the start and end indices of the split.
         """
-        self.split_start = split_id * self.split_size
-        self.split_end = min(self.split_start + self.split_size, self.total_size)
+        self.split_start, self.split_end = self.splits[split_id]
         return self.split_start, self.split_end

     def splits_start_end_iterator(self) -> tuple[int, int]:
@@ -126,7 +132,7 @@ def splits_start_end_iterator(self) -> tuple[int, int]:
         Yields:
             tuple: A tuple containing the start and end indices of a split.
         """
-        for split_id in range(self.dataset_splits):
+        for split_id in range(self.num_dataset_splits):
             yield self.get_split_start_end(split_id)

     def __getitem__(self, index) -> Request:
@@ -204,7 +210,47 @@ def _sorting_criteria(self, request: LoglikelihoodSingleTokenRequest) -> int:


 class GenerativeTaskDataset(DynamicBatchDataset):
-    def _sorting_criteria(self, request: GreedyUntilRequest) -> int:
+    def init_split_limits(self, num_dataset_splits):
+        """Initialises the split limits based on generation parameters.
+        The splits are used to estimate time remaining when evaluating, and in the case of generative evaluations, to group similar samples together.
+
+        For generative tasks, self._sorting_criteria outputs:
+        - a boolean (whether the generation task uses logits)
+        - a list (the stop sequences)
+        - the item length (the actual size sorting factor).
+
+        In the current function, we create evaluation groups by generation parameters (logits and eos), so that samples with similar properties get batched together afterwards.
+        The samples will then be further organised by length in each split.
+
+        Args:
+            num_dataset_splits (_type_): _description_
+
+        Returns:
+            _type_: _description_
+        """
+        if num_dataset_splits is not None:
+            hlog_warn(
+                "You cannot select the number of dataset splits for a generative evaluation at the moment. Automatically inferring."
+            )
+
+        all_sorting_criterion = [self._sorting_criteria(self.sorted_data[0])[:2]]
+        splits_indices = [[0, None]]
+        for ix, req in enumerate(self.sorted_data):
+            current_sorting_criteria = self._sorting_criteria(req)
+            current_key = current_sorting_criteria[:2]
+            if current_key not in all_sorting_criterion:
+                all_sorting_criterion.append(current_key)
+                splits_indices[-1][1] = ix
+                splits_indices.append([ix, None])
+
+        # We add the last split
+        splits_indices[-1][1] = self.total_size
+
+        num_dataset_splits = len(splits_indices)
+        splits_indices = [tuple(e) for e in splits_indices]
+        return num_dataset_splits, splits_indices
+
+    def _sorting_criteria(self, request: GreedyUntilRequest) -> tuple[bool, list, int]:
         """
         Collate function for generating batches.
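The grouping behaviour introduced here is easiest to see on toy data. A hedged sketch of the loop above, where the tuples stand in for self._sorting_criteria(req)[:2], i.e. (use_logits, stop_sequence); it relies on sorted_data already placing requests with equal keys next to each other:

# Toy run of the grouping loop: four requests, three distinct generation-parameter keys.
keys = [(False, ["\n"]), (False, ["\n"]), (True, ["\n"]), (True, ["</s>"])]

all_sorting_criterion = [keys[0]]
splits_indices = [[0, None]]
for ix, current_key in enumerate(keys):
    if current_key not in all_sorting_criterion:
        all_sorting_criterion.append(current_key)
        splits_indices[-1][1] = ix   # close the previous group
        splits_indices.append([ix, None])
splits_indices[-1][1] = len(keys)    # close the last group

print([tuple(e) for e in splits_indices])  # [(0, 2), (2, 3), (3, 4)]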
@@ -219,10 +265,10 @@ def _sorting_criteria(self, request: GreedyUntilRequest) -> int:
         # The generative task has no limit except the model context
         if gen_length is None:
             gen_length = 0
-        return -(len(toks) + gen_length)
+        return request.use_logits, request.stop_sequence, -(len(toks) + gen_length)


-class GenerativeTaskDatasetNanotron(DynamicBatchDataset):
+class GenerativeTaskDatasetNanotron(GenerativeTaskDataset):
     def __getitem__(self, index) -> Request:
         """
         Get an item from the dataset depending on the split we are currently in.
Expand All @@ -238,20 +284,6 @@ def __getitem__(self, index) -> Request:
"""
return index, self.sorted_data[index + self.split_start]

def _sorting_criteria(self, request) -> int:
"""
Collate function for generating batches.
Args:
x (Any): The input data.
Returns:
Any: The collated data.
"""
toks = request.tokenized_context
gen_length = request.generation_size
return -(len(toks) + gen_length)


class GenDistributedSampler(DistributedSampler):
"""A distributed sampler that copy the last element only when drop_last is False so we keep a small padding in the batches
Expand Down
4 changes: 2 additions & 2 deletions src/lighteval/metrics/__init__.py
@@ -100,15 +100,15 @@ def apply_generative_metric(
             outputs.update(
                 Metrics[metric].value.compute(
                     golds=golds,
-                    predictions=as_list(preds[0]) if max_num_samples > 0 else preds,
+                    predictions=as_list(preds[0]) if max_num_samples > 1 else preds,
                     formatted_doc=formatted_doc,
                 )
             )
     if Metrics[metric].value.category == MetricCategory.GENERATIVE_LOGPROB:
         outputs.update(
             Metrics[metric].value.compute(
                 golds=golds,
-                predictions=as_list(preds[0]) if max_num_samples > 0 else preds,
+                predictions=as_list(preds[0]) if max_num_samples > 1 else preds,
                 formatted_doc=formatted_doc,
             )
         )
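Reading the two-character fix: max_num_samples > 0 presumably holds for every run, so the old condition always replaced the prediction list with its unwrapped first element; > 1 restricts that unwrapping to genuinely multi-sample runs. A sketch under an assumed nesting (whether preds actually nests per-doc samples this way is an assumption of this note, and as_list is elided):

# Hypothetical shapes, for illustration only.
preds_multi = [["sample 1", "sample 2"]]  # several samples drawn for one doc
preds_single = ["the only sample"]        # a single greedy generation

def select_predictions(preds, max_num_samples):
    # New condition: only unwrap preds[0] when several samples were drawn.
    return preds[0] if max_num_samples > 1 else preds

assert select_predictions(preds_multi, 2) == ["sample 1", "sample 2"]
assert select_predictions(preds_single, 1) == ["the only sample"]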
10 changes: 5 additions & 5 deletions src/lighteval/models/base_model.py
@@ -335,7 +335,7 @@ def greedy_until_multi_turn(  # noqa: C901

         results = []

-        dataset = GenerativeTaskDataset(requests=requests, dataset_splits=1)
+        dataset = GenerativeTaskDataset(requests=requests, num_dataset_splits=1)
         dataloader = DataLoader(dataset, batch_size=1, collate_fn=lambda batch: batch)

         if self.accelerator:
@@ -480,13 +480,13 @@ def greedy_until(
             request.stop_sequence = as_list(request.stop_sequence) + [self.tokenizer.eos_token]
             request.tokenized_context = self.tok_encode(request.context)

-        dataset = GenerativeTaskDataset(requests=requests, dataset_splits=self.DATASET_SPLITS)
+        dataset = GenerativeTaskDataset(requests=requests, num_dataset_splits=self.DATASET_SPLITS)
         starting_batch_size = STARTING_BATCH_SIZE
         results = []

         for split_start, split_end in tqdm(
             dataset.splits_start_end_iterator(),
-            total=self.DATASET_SPLITS,
+            total=dataset.num_dataset_splits,
             desc="Splits",
             position=0,
             disable=self.disable_tqdm,
@@ -715,7 +715,7 @@ def _loglikelihood_tokens(
         return_bool_score: bool = True,
         rolling: bool = False,
     ) -> list[LoglikelihoodReturn]:
-        dataset = LoglikelihoodDataset(requests=requests, dataset_splits=self.DATASET_SPLITS)
+        dataset = LoglikelihoodDataset(requests=requests, num_dataset_splits=self.DATASET_SPLITS)
         starting_batch_size = STARTING_BATCH_SIZE
         res = []

@@ -957,7 +957,7 @@ def loglikelihood_single_token(
     def _loglikelihood_single_token(
         self, requests: list[LoglikelihoodSingleTokenRequest], override_bs: int = -1
     ) -> list[LoglikelihoodSingleTokenReturn]:
-        dataset = LoglikelihoodSingleTokenDataset(requests=requests, dataset_splits=self.DATASET_SPLITS)
+        dataset = LoglikelihoodSingleTokenDataset(requests=requests, num_dataset_splits=self.DATASET_SPLITS)
         starting_batch_size = STARTING_BATCH_SIZE
         res = []
6 changes: 3 additions & 3 deletions src/lighteval/models/endpoint_model.py
@@ -243,7 +243,7 @@ def greedy_until(
             request.tokenized_context = self.tok_encode(request.context)
             request.stop_sequence = as_list(request.stop_sequence) + [self.tokenizer.eos_token]

-        dataset = GenerativeTaskDataset(requests=requests, dataset_splits=self.DATASET_SPLITS)
+        dataset = GenerativeTaskDataset(requests=requests, num_dataset_splits=self.DATASET_SPLITS)
         batch_size = override_bs if override_bs is not None else BATCH_SIZE
         results: List[str] = []

@@ -289,7 +289,7 @@ def loglikelihood(
         for request in requests:
             request.tokenized_context = self.tok_encode(request.context)
             request.tokenized_continuation = self.tok_encode(request.choice)
-        dataset = LoglikelihoodDataset(requests=requests, dataset_splits=self.DATASET_SPLITS)
+        dataset = LoglikelihoodDataset(requests=requests, num_dataset_splits=self.DATASET_SPLITS)
         batch_size = override_bs if override_bs is not None else BATCH_SIZE
         results: List[str] = []

@@ -335,7 +335,7 @@ def loglikelihood_rolling(
             request.tokenized_context = [self.tokenizer.eos_token_id]
             request.tokenized_continuation = self.tok_encode(request.context)

-        dataset = LoglikelihoodDataset(requests=requests, dataset_splits=self.DATASET_SPLITS)
+        dataset = LoglikelihoodDataset(requests=requests, num_dataset_splits=self.DATASET_SPLITS)
         batch_size = override_bs if override_bs is not None else BATCH_SIZE
         results: List[str] = []
20 changes: 10 additions & 10 deletions src/lighteval/models/nanotron_model.py
@@ -643,17 +643,17 @@ def pad_and_gather(self, output_tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:

         return gathered_outputs, gathered_length

-    def _get_subsets(self, dataset, dataset_splits):
+    def _get_subsets(self, dataset, num_dataset_splits):
         total_length = len(dataset)
-        subset_length = int(float(total_length) / float(dataset_splits)) + 1
+        subset_length = int(float(total_length) / float(num_dataset_splits)) + 1
         if subset_length < self.parallel_context.dp_pg.size():
             # We need at least one subset sample per DP process
             subset_length = self.parallel_context.dp_pg.size()
         return total_length, subset_length

     @torch.inference_mode()
     def _loglikelihood_single_token(
-        self, requests, disable_tqdm: bool = False, override_bs: int = -1, dataset_splits: int = 1
+        self, requests, disable_tqdm: bool = False, override_bs: int = -1, num_dataset_splits: int = 1
     ) -> List[LoglikelihoodSingleTokenReturn]:
         dataset = LoglikelihoodSingleTokenDataset(requests=requests)
         res = []
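The renamed _get_subsets keeps its arithmetic unchanged; a quick worked example (the standalone function and dp_size values are illustrative, with dp_size standing in for self.parallel_context.dp_pg.size()):

# Illustrative standalone version of the renamed helper.
def get_subsets(total_length: int, num_dataset_splits: int, dp_size: int):
    subset_length = int(float(total_length) / float(num_dataset_splits)) + 1
    if subset_length < dp_size:
        # We need at least one subset sample per DP process
        subset_length = dp_size
    return total_length, subset_length

print(get_subsets(1000, 3, 8))  # (1000, 334): int(1000 / 3) + 1
print(get_subsets(10, 8, 4))    # (10, 4): int(10 / 8) + 1 = 2, raised to dp_size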
@@ -663,7 +663,7 @@ def _loglikelihood_single_token(
         printed_error = False
         starting_batch_size = 512

-        total_length, subset_length = self._get_subsets(dataset, dataset_splits)
+        total_length, subset_length = self._get_subsets(dataset, num_dataset_splits)

         for s, subset_start in enumerate(
             tqdm(
@@ -883,17 +883,17 @@ def _loglikelihood_tokens(
         requests,
         disable_tqdm: bool = False,
         override_bs: int = -1,
-        dataset_splits: int = 1,
+        num_dataset_splits: int = 1,
         return_bool_score: bool = True,
     ) -> List[LoglikelihoodReturn]:
-        dataset = LoglikelihoodDataset(requests=requests, dataset_splits=dataset_splits)
+        dataset = LoglikelihoodDataset(requests=requests, num_dataset_splits=num_dataset_splits)
         res = []

         # Dataset is sorted in descending size.
         # every 20-25% of the dataset we try to double the batch size for speed up
         starting_batch_size = 512

-        total_length, subset_length = self._get_subsets(dataset, dataset_splits)
+        total_length, subset_length = self._get_subsets(dataset, num_dataset_splits)

         for s, subset_start in enumerate(
             tqdm(
@@ -1117,7 +1117,7 @@ def greedy_until(
         requests: List[GreedyUntilRequest],
         disable_tqdm: bool = False,
         override_bs=None,
-        dataset_splits: int = 1,
+        num_dataset_splits: int = 1,
     ) -> List[GenerateReturn]:
         """Greedy generation until a stop token is generated."""
         # automatic (variable) batch size detection for vectorization
@@ -1126,14 +1126,14 @@ def greedy_until(
             request.stop_sequence = as_list(request.stop_sequence) + [self.tokenizer.eos_token]
             request.tokenized_context = self.tok_encode(request.context)

-        dataset = GenerativeTaskDatasetNanotron(requests=requests, dataset_splits=dataset_splits)
+        dataset = GenerativeTaskDatasetNanotron(requests=requests, num_dataset_splits=num_dataset_splits)
         res = []

         # Dataset is sorted in descending size.
         # every 20-25% of the dataset we try to double the batch size for speed up
         starting_batch_size = 512

-        total_length, subset_length = self._get_subsets(dataset, dataset_splits)
+        total_length, subset_length = self._get_subsets(dataset, num_dataset_splits)

         for s, subset_start in enumerate(
             tqdm(
4 changes: 2 additions & 2 deletions tests/test_unit_reorder.py
@@ -77,15 +77,15 @@
 class TestReorderGenerativeTaskDataset:
     def test_dataset_needs_tokenization(self):
         with pytest.raises(ValueError):
-            GenerativeTaskDataset(requests=TEST_DATA, dataset_splits=DATASET_SPLITS)
+            GenerativeTaskDataset(requests=TEST_DATA, num_dataset_splits=DATASET_SPLITS)

     def test_reorder_dataset(self):
         tokenizer = AutoTokenizer.from_pretrained("gpt2")
         data = TEST_DATA.copy()
         for request in data:
             request.tokenized_context = tokenizer.encode(request.context)

-        dataset = GenerativeTaskDataset(requests=data, dataset_splits=DATASET_SPLITS)
+        dataset = GenerativeTaskDataset(requests=data, num_dataset_splits=DATASET_SPLITS)

         sorted_data = dataset.sorted_data
         original_data = dataset.get_original_order(sorted_data)
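The round-trip invariant this test relies on, sorting for dynamic batching and then restoring the original order, can be sketched independently of the repository's request objects (toy strings stand in for tokenized requests):

# Sketch of the sort/restore round trip checked by test_reorder_dataset.
requests = ["bb", "a", "dddd", "ccc"]
order = sorted(range(len(requests)), key=lambda i: -len(requests[i]))
sorted_data = [requests[i] for i in order]       # longest first

restored = [None] * len(requests)
for pos, original_index in enumerate(order):
    restored[original_index] = sorted_data[pos]  # analogous to get_original_order

assert restored == requests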
