Skip to content

Commit

Permalink
fix pyright
Browse files Browse the repository at this point in the history
  • Loading branch information
bmosaicml committed Apr 11, 2024
1 parent 779f490 commit 03f7e91
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions llmfoundry/eval/datasets/in_context_learning_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,7 +478,8 @@ def collate_fn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]:
batch['attention_mask'] = ~(batch['input_ids'] == self.pad_tok_id)
return batch

def split_batch(self, batch: Any, microbatch_size: Union[int , float]) -> Sequence[Any]:
def split_batch(self, batch: Any,
microbatch_size: Union[int, float]) -> Sequence[Any]:
"""Handling for certain specialty columns that must be split into.
batches in different formats.
Expand All @@ -494,7 +495,8 @@ def split_batch(self, batch: Any, microbatch_size: Union[int , float]) -> Sequen
# Normally split torch tensors
# List split lists of strings
if isinstance(microbatch_size, float):
raise ValueError('split_batch does not support floating point microbatch_size.')
raise ValueError(
'split_batch does not support floating point microbatch_size.')
chunked = {}
for k, v in batch.items():
if k in self.static_keys:
Expand Down Expand Up @@ -907,7 +909,8 @@ def collate_fn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]:
def get_num_samples_in_batch(self, batch: Dict[str, torch.Tensor]) -> int:
return batch['input_ids'].shape[0] // self.num_choices

def split_batch(self, batch: Any, microbatch_size: Union[int , float]) -> Sequence[Any]:
def split_batch(self, batch: Any,
microbatch_size: Union[int, float]) -> Sequence[Any]:
"""Split batch while ensuring all continuations are in the same.
microbatch.
Expand All @@ -925,7 +928,8 @@ def split_batch(self, batch: Any, microbatch_size: Union[int , float]) -> Sequen
list: List of chunked batches
"""
if isinstance(microbatch_size, float):
raise ValueError('split_batch does not support floating point microbatch_size.')
raise ValueError(
'split_batch does not support floating point microbatch_size.')
chunked = {}
for k, v in batch.items():
if k in self.static_keys:
Expand Down

0 comments on commit 03f7e91

Please sign in to comment.