From 019c58a44c44bb2abb4751c7f4cdb9ac00702280 Mon Sep 17 00:00:00 2001 From: Jeremy Dohmann Date: Thu, 11 Apr 2024 19:32:09 -0400 Subject: [PATCH] fix pyright error again --- .../eval/datasets/in_context_learning_evaluation.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/llmfoundry/eval/datasets/in_context_learning_evaluation.py b/llmfoundry/eval/datasets/in_context_learning_evaluation.py index 55b21a1fa0..30502d2d92 100644 --- a/llmfoundry/eval/datasets/in_context_learning_evaluation.py +++ b/llmfoundry/eval/datasets/in_context_learning_evaluation.py @@ -478,14 +478,14 @@ def collate_fn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]: batch['attention_mask'] = ~(batch['input_ids'] == self.pad_tok_id) return batch - def split_batch(self, batch: Any, microbatch_size: int) -> Sequence[Any]: + def split_batch(self, batch: Any, microbatch_size: Union[int , float]) -> Sequence[Any]: """Handling for certain specialty columns that must be split into. batches in different formats. Args: batch (Dict): Batch of data - microbatch_size (int): Size of microbatches + microbatch_size (int | float): Size of microbatches Returns: List: List of chunked batches @@ -905,7 +905,7 @@ def collate_fn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]: def get_num_samples_in_batch(self, batch: Dict[str, torch.Tensor]) -> int: return batch['input_ids'].shape[0] // self.num_choices - def split_batch(self, batch: Any, microbatch_size: int) -> Sequence[Any]: + def split_batch(self, batch: Any, microbatch_size: Union[int , float]) -> Sequence[Any]: """Split batch while ensuring all continuations are in the same. microbatch. @@ -917,7 +917,7 @@ def split_batch(self, batch: Any, microbatch_size: int) -> Sequence[Any]: microbatch_size and real attributes by microbatch_size * num_choices. Args: batch (Dict): Batch of data - microbatch_size (int): Size of microbatches + microbatch_size (int | float): Size of microbatches Returns: list: List of chunked batches