diff --git a/llmfoundry/eval/datasets/in_context_learning_evaluation.py b/llmfoundry/eval/datasets/in_context_learning_evaluation.py
index df5799df2b..8f317f60b8 100644
--- a/llmfoundry/eval/datasets/in_context_learning_evaluation.py
+++ b/llmfoundry/eval/datasets/in_context_learning_evaluation.py
@@ -478,7 +478,8 @@ def collate_fn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]:
         batch['attention_mask'] = ~(batch['input_ids'] == self.pad_tok_id)
         return batch

-    def split_batch(self, batch: Any, microbatch_size: Union[int , float]) -> Sequence[Any]:
+    def split_batch(self, batch: Any,
+                    microbatch_size: Union[int, float]) -> Sequence[Any]:
         """Handling for certain specialty columns that must be split into.

         batches in different formats.
@@ -494,7 +495,8 @@ def split_batch(self, batch: Any, microbatch_size: Union[int , float]) -> Sequen
         # Normally split torch tensors
         # List split lists of strings
         if isinstance(microbatch_size, float):
-            raise ValueError('split_batch does not support floating point microbatch_size.')
+            raise ValueError(
+                'split_batch does not support floating point microbatch_size.')
         chunked = {}
         for k, v in batch.items():
             if k in self.static_keys:
@@ -907,7 +909,8 @@ def collate_fn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]:
     def get_num_samples_in_batch(self, batch: Dict[str, torch.Tensor]) -> int:
         return batch['input_ids'].shape[0] // self.num_choices

-    def split_batch(self, batch: Any, microbatch_size: Union[int , float]) -> Sequence[Any]:
+    def split_batch(self, batch: Any,
+                    microbatch_size: Union[int, float]) -> Sequence[Any]:
         """Split batch while ensuring all continuations are in the same.

         microbatch.
@@ -925,7 +928,8 @@ def split_batch(self, batch: Any, microbatch_size: Union[int , float]) -> Sequen
             list: List of chunked batches
         """
         if isinstance(microbatch_size, float):
-            raise ValueError('split_batch does not support floating point microbatch_size.')
+            raise ValueError(
+                'split_batch does not support floating point microbatch_size.')
         chunked = {}
         for k, v in batch.items():
             if k in self.static_keys:
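
For context on what the two `split_batch` methods touched above do, here is a minimal, hypothetical sketch of the chunking pattern they follow; it is not the llm-foundry implementation itself. The helper name `split_batch_sketch` and the explicit `static_keys` parameter are assumptions made for illustration (in the real classes, `static_keys`, `list_keys`, and `tensor_keys` are instance attributes). As in the diff, float microbatch sizes are rejected, static values are copied into every microbatch, and list and tensor values are split along the batch dimension.

```python
# Minimal sketch of the split_batch chunking pattern (illustrative only).
from typing import Any, Dict, List, Sequence, Union

import torch


def split_batch_sketch(
    batch: Dict[str, Any],
    microbatch_size: Union[int, float],
    static_keys: Sequence[str],
) -> List[Dict[str, Any]]:
    # Mirrors the guard added in the diff: only integer microbatch sizes are supported.
    if isinstance(microbatch_size, float):
        raise ValueError(
            'split_batch does not support floating point microbatch_size.')
    chunked: Dict[str, List[Any]] = {}
    for k, v in batch.items():
        if k in static_keys:
            # Static values are not split; they are replicated into every microbatch below.
            continue
        if isinstance(v, torch.Tensor):
            # Normally split torch tensors along the batch dimension.
            chunked[k] = list(torch.split(v, microbatch_size))
        elif isinstance(v, list):
            # List values (e.g. raw strings) are sliced into equal-sized chunks.
            chunked[k] = [
                v[i:i + microbatch_size]
                for i in range(0, len(v), microbatch_size)
            ]
        else:
            raise ValueError(f'Unsupported batch value for key {k}: {type(v)}')
    num_chunks = len(next(iter(chunked.values())))
    return [
        {
            **{k: v[i] for k, v in chunked.items()},
            **{k: batch[k] for k in static_keys},
        }
        for i in range(num_chunks)
    ]


# Usage example: a batch of 4 examples split into microbatches of 2.
batch = {
    'input_ids': torch.arange(8).reshape(4, 2),
    'labels': ['a', 'b', 'c', 'd'],
    'mode': 'icl_task',  # static key, copied into every microbatch
}
microbatches = split_batch_sketch(batch, 2, static_keys=['mode'])
assert len(microbatches) == 2
assert microbatches[0]['input_ids'].shape[0] == 2
```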