Skip to content

Commit

Permalink
fix pyright error again
Browse files Browse the repository at this point in the history
  • Loading branch information
bmosaicml committed Apr 11, 2024
1 parent de321b2 commit 019c58a
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions llmfoundry/eval/datasets/in_context_learning_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,14 +478,14 @@ def collate_fn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]:
batch['attention_mask'] = ~(batch['input_ids'] == self.pad_tok_id)
return batch

def split_batch(self, batch: Any, microbatch_size: int) -> Sequence[Any]:
def split_batch(self, batch: Any, microbatch_size: Union[int , float]) -> Sequence[Any]:
"""Handling for certain specialty columns that must be split into.
batches in different formats.
Args:
batch (Dict): Batch of data
microbatch_size (int): Size of microbatches
microbatch_size (int | float): Size of microbatches
Returns:
List: List of chunked batches
Expand Down Expand Up @@ -905,7 +905,7 @@ def collate_fn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]:
def get_num_samples_in_batch(self, batch: Dict[str, torch.Tensor]) -> int:
return batch['input_ids'].shape[0] // self.num_choices

def split_batch(self, batch: Any, microbatch_size: int) -> Sequence[Any]:
def split_batch(self, batch: Any, microbatch_size: Union[int , float]) -> Sequence[Any]:
"""Split batch while ensuring all continuations are in the same.
microbatch.
Expand All @@ -917,7 +917,7 @@ def split_batch(self, batch: Any, microbatch_size: int) -> Sequence[Any]:
microbatch_size and real attributes by microbatch_size * num_choices.
Args:
batch (Dict): Batch of data
microbatch_size (int): Size of microbatches
microbatch_size (int | float): Size of microbatches
Returns:
list: List of chunked batches
Expand Down

0 comments on commit 019c58a

Please sign in to comment.