From 642ad40a1dbed0bb17a100528191f2cb80be3e98 Mon Sep 17 00:00:00 2001 From: Max Marion Date: Wed, 10 Apr 2024 22:32:38 +0000 Subject: [PATCH] pyright wip --- .../eval/datasets/in_context_learning_evaluation.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/llmfoundry/eval/datasets/in_context_learning_evaluation.py b/llmfoundry/eval/datasets/in_context_learning_evaluation.py index c582be798e..bd5c7dc30c 100644 --- a/llmfoundry/eval/datasets/in_context_learning_evaluation.py +++ b/llmfoundry/eval/datasets/in_context_learning_evaluation.py @@ -9,7 +9,7 @@ import os import random import warnings -from typing import Any, Dict, Iterable, List, Optional, Union +from typing import Any, Dict, Iterable, List, Optional, Sequence, Union import torch import transformers @@ -478,8 +478,7 @@ def collate_fn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]: batch['attention_mask'] = ~(batch['input_ids'] == self.pad_tok_id) return batch - def split_batch(self, batch: Any, - microbatch_size: int) -> List[Dict[str, Any]]: + def split_batch(self, batch: Any, microbatch_size: int) -> Sequence: """Handling for certain specialty columns that must be split into. batches in different formats. @@ -906,8 +905,7 @@ def collate_fn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]: def get_num_samples_in_batch(self, batch: Dict[str, torch.Tensor]) -> int: return batch['input_ids'].shape[0] // self.num_choices - def split_batch(self, batch: Any, - microbatch_size: int) -> List[Dict[str, Any]]: + def split_batch(self, batch: Any, microbatch_size: int) -> Sequence: """Split batch while ensuring all continuations are in the same. microbatch.