fix split_batch in generation_task_with_answers
sanjari-orb committed Jun 11, 2024
1 parent 87bba14 commit 2332461
Showing 1 changed file with 42 additions and 40 deletions.
82 changes: 42 additions & 40 deletions llmfoundry/eval/datasets/in_context_learning_evaluation.py
@@ -531,6 +531,8 @@ def collate_fn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]:
    def split_batch(self, batch: Any,
                    microbatch_size: Union[int, float]) -> Sequence[Any]:
        return _default_split_batch(batch, microbatch_size)


class InContextLearningGenerationTaskWithAnswersDataset(
    InContextLearningDataset,
):
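For reference, the base-class split_batch above delegates to composer's _default_split_batch, and the method added in the next hunk reuses the same module-level helpers per key. Below is a minimal sketch of the assumed semantics of those helpers; this is an illustrative reimplementation for this commit's context, not composer's actual code (the real helpers also handle Mapping and tuple batches):

import math
from typing import Any, List, Sequence

import torch


def _split_list(v: List[Any], microbatch_size: int) -> List[List[Any]]:
    # Chunk a list into ceil(len(v) / microbatch_size) contiguous pieces.
    num_chunks = math.ceil(len(v) / microbatch_size)
    return [
        v[i * microbatch_size:(i + 1) * microbatch_size]
        for i in range(num_chunks)
    ]


def _default_split_batch(batch: Any, microbatch_size: int) -> Sequence[Any]:
    # Tensors are split along dim 0; lists are split list-wise.
    if isinstance(batch, torch.Tensor):
        return list(torch.split(batch, microbatch_size, dim=0))
    if isinstance(batch, list):
        return _split_list(batch, microbatch_size)
    raise NotImplementedError(f'Cannot split batch of type {type(batch)}')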
@@ -785,6 +787,46 @@ def collate_fn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]:
        batch['generation_kwargs']['stopping_criteria'] = stopping_criteria
        return batch

    def split_batch(self, batch: Any,
                    microbatch_size: Union[int, float]) -> Sequence[Any]:
        """Handles specialty columns that must be split into batches in
        different formats.

        Args:
            batch (Dict): Batch of data
            microbatch_size (int | float): Size of microbatches

        Returns:
            List: List of chunked batches
        """
        # Don't split kwargs that don't change between chunks; split torch
        # tensors normally; split lists of strings list-wise.
        if isinstance(microbatch_size, float):
            raise ValueError(
                'split_batch does not support floating point microbatch_size.',
            )
        chunked = {}
        for k, v in batch.items():
            if k in self.static_keys:
                # Defer broadcasting until we know num_chunks
                pass
            elif k in self.list_keys:
                chunked[k] = _split_list(v, microbatch_size)
            elif k in self.tensor_keys:
                chunked[k] = _default_split_batch(v, microbatch_size)
            else:
                raise ValueError(f'Unexpected key {k} in batch splitting')
        num_chunks = len(chunked['input_ids'])
        for k, v in batch.items():
            if k in self.static_keys:
                chunked[k] = [v] * num_chunks

        batched_list = [{k: v[idx]
                         for k, v in chunked.items()}
                        for idx in range(num_chunks)]
        return batched_list

class InContextLearningLMTaskDataset(InContextLearningDataset):
    """A dataset that constructs batches for in-context learning language
    modeling evaluation.
@@ -1457,46 +1499,6 @@ def tokenize_example(
        tokenized_example['gold'] = example['gold']
        return tokenized_example

    def split_batch(self, batch: Any,
                    microbatch_size: Union[int, float]) -> Sequence[Any]:
        """Handles specialty columns that must be split into batches in
        different formats.

        Args:
            batch (Dict): Batch of data
            microbatch_size (int | float): Size of microbatches

        Returns:
            List: List of chunked batches
        """
        # Don't split kwargs that don't change between chunks; split torch
        # tensors normally; split lists of strings list-wise.
        if isinstance(microbatch_size, float):
            raise ValueError(
                'split_batch does not support floating point microbatch_size.',
            )
        chunked = {}
        for k, v in batch.items():
            if k in self.static_keys:
                # Defer broadcasting until we know num_chunks
                pass
            elif k in self.list_keys:
                chunked[k] = _split_list(v, microbatch_size)
            elif k in self.tensor_keys:
                chunked[k] = _default_split_batch(v, microbatch_size)
            else:
                raise ValueError(f'Unexpected key {k} in batch splitting')
        num_chunks = len(chunked['input_ids'])
        for k, v in batch.items():
            if k in self.static_keys:
                chunked[k] = [v] * num_chunks

        batched_list = [{k: v[idx]
                         for k, v in chunked.items()}
                        for idx in range(num_chunks)]
        return batched_list


def build_icl_dataloader(
