Skip to content

Commit

Permalink
pyright wip
Browse files Browse the repository at this point in the history
  • Loading branch information
maxisawesome committed Apr 10, 2024
1 parent 3c8ac56 commit 642ad40
Showing 1 changed file with 3 additions and 5 deletions.
8 changes: 3 additions & 5 deletions llmfoundry/eval/datasets/in_context_learning_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import os
import random
import warnings
from typing import Any, Dict, Iterable, List, Optional, Union
from typing import Any, Dict, Iterable, List, Optional, Sequence, Union

import torch
import transformers
Expand Down Expand Up @@ -478,8 +478,7 @@ def collate_fn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]:
batch['attention_mask'] = ~(batch['input_ids'] == self.pad_tok_id)
return batch

def split_batch(self, batch: Any,
microbatch_size: int) -> List[Dict[str, Any]]:
def split_batch(self, batch: Any, microbatch_size: int) -> Sequence:
"""Handling for certain specialty columns that must be split into.
batches in different formats.
Expand Down Expand Up @@ -906,8 +905,7 @@ def collate_fn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]:
def get_num_samples_in_batch(self, batch: Dict[str, torch.Tensor]) -> int:
return batch['input_ids'].shape[0] // self.num_choices

def split_batch(self, batch: Any,
microbatch_size: int) -> List[Dict[str, Any]]:
def split_batch(self, batch: Any, microbatch_size: int) -> Sequence:
"""Split batch while ensuring all continuations are in the same.
microbatch.
Expand Down

0 comments on commit 642ad40

Please sign in to comment.