Clean up
sanjari-orb committed Jun 11, 2024
1 parent 2332461 commit eb9ac84
Showing 1 changed file with 11 additions and 31 deletions.
42 changes: 11 additions & 31 deletions llmfoundry/eval/datasets/in_context_learning_evaluation.py
@@ -193,10 +193,8 @@ def __len__(self) -> int:
     def get_num_samples_in_batch(self, batch: Dict) -> int:
         return batch['input_ids'].shape[0]
 
-    def get_effective_batch_size(self) -> int:
-        return NotImplementedError(
-            "Calculation for effective batch size not implemented"
-        )
+    def get_effective_batch_size(self, batch_size: int) -> int:
+        return batch_size
 
     def update_generation_kwargs(self, generation_kwargs: Dict) -> None:
         r"""Updates self.base_batch with the passed in generation_kwargs.
@@ -528,8 +526,11 @@ def collate_fn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]:
         batch['attention_mask'] = ~(batch['input_ids'] == self.pad_tok_id)
         return batch
 
-    def split_batch(self, batch: Any,
-                    microbatch_size: Union[int, float]) -> Sequence[Any]:
+    def split_batch(
+        self,
+        batch: Any,
+        microbatch_size: Union[int, float]
+    ) -> Sequence[Any]:
         return _default_split_batch(batch, microbatch_size)
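
Note: the multi-line signature is purely stylistic; behavior is unchanged. Splitting is delegated to Composer's `_default_split_batch`, which chunks every entry of the batch dict into microbatches. A minimal usage sketch (the toy shapes and sizes are assumptions):

    import torch
    from composer.core.data_spec import _default_split_batch

    batch = {
        'input_ids': torch.zeros(4, 8, dtype=torch.long),
        'attention_mask': torch.ones(4, 8, dtype=torch.bool),
    }
    # Split 4 examples into microbatches of 2 examples each.
    micro = _default_split_batch(batch, microbatch_size=2)
    assert len(micro) == 2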


@@ -570,17 +571,12 @@ def __init__(
         context_key: str = 'context',
         answer_key: str = 'answer',
         strip_dataset: bool = True,
-        # padding_side: str = 'right',
-        # tokenize_labels: bool = True,
         padding_size: Optional[int] = None,
         base_batch: Optional[Dict] = None,
         batch_mapping: Optional[Dict] = None,
         hf_loading_vars: Optional[Dict] = None,
         hf_parsing_map: Optional[Dict] = None,
         generation_kwargs: Optional[Dict] = None,
-        # static_keys: Optional[List] = None,
-        # list_keys: Optional[List] = None,
-        # tensor_keys: Optional[List] = None,
         cot_delimiter: str = '',
         early_stopping_criteria: Optional[List[str]] = None,
         do_normalization: bool = True,
@@ -604,7 +600,6 @@ def __init__(
         tensor_keys = ['input_ids', 'attention_mask']
         list_keys = ['labels']
         super().__init__(
-            # Super class
             dataset_uri=dataset_uri,
             tokenizer=tokenizer,
             max_seq_len=max_seq_len,
@@ -625,7 +620,7 @@ def __init__(
             hf_loading_vars=hf_loading_vars,
             hf_parsing_map=hf_parsing_map,
             generation_kwargs=generation_kwargs,
-            # Custom
+            # specific to ICL dataset
             padding_side='left',
             tokenize_labels=False,
             static_keys=static_keys,
@@ -657,9 +652,6 @@ def __init__(
         if 'generation_kwargs' in kwargs:
             self.update_generation_kwargs(kwargs['generation_kwargs'])
 
-    def get_effective_batch_size(self, batch_size: int) -> int:
-        return batch_size
-
     def read_dataset(
         self,
         dataset_uri: str,
@@ -855,24 +847,18 @@ def __init__(
         destination_path: str,
         prelimiter: str = '',
         context_key: str = 'context',
-        # answer_key: str = 'answer',
         strip_dataset: bool = True,
-        # padding_side: str = 'right',
         tokenize_labels: bool = True,
         padding_size: Optional[int] = None,
-        # base_batch: Optional[Dict] = None,
-        # batch_mapping: Optional[Dict] = None,
         hf_loading_vars: Optional[Dict] = None,
         hf_parsing_map: Optional[Dict] = None,
         generation_kwargs: Optional[Dict] = None,
         static_keys: Optional[List] = None,
         list_keys: Optional[List] = None,
-        # tensor_keys: Optional[List] = None,
         *args: Any,
         **kwargs: Any,
     ):
         super().__init__(
-            # Super class
             dataset_uri=dataset_uri,
             tokenizer=tokenizer,
             max_seq_len=max_seq_len,
@@ -888,12 +874,11 @@ def __init__(
             strip_dataset=strip_dataset,
             tokenize_labels=tokenize_labels,
             padding_size=padding_size,
-            # base_batch=base_batch,
-            # batch_mapping=batch_mapping,
             hf_loading_vars=hf_loading_vars,
             hf_parsing_map=hf_parsing_map,
             generation_kwargs=generation_kwargs,
             list_keys=list_keys,
+            # specific to ICL dataset
             answer_key='continuation',
             static_keys=['mode'],
             tensor_keys=[
@@ -917,8 +902,6 @@ def __init__(
             **kwargs,
         )
 
-    def get_effective_batch_size(self, batch_size: int) -> int:
-        return batch_size
 
 class InContextLearningMultipleChoiceTaskDataset(InContextLearningDataset):
     """A dataset that construct batches for in-context learning multiple choice.
@@ -997,7 +980,6 @@ def __init__(
         self.list_of_tuples_keys = list_of_tuples_keys or ['choice_groupings']
         self.list_of_primitives = list_of_primitives or ['gold_indices']
         super().__init__(
-            # Super
             dataset_uri=dataset_uri,
             tokenizer=tokenizer,
             max_seq_len=max_seq_len,
@@ -1018,7 +1000,7 @@ def __init__(
             hf_parsing_map=hf_parsing_map,
             generation_kwargs=generation_kwargs,
             list_keys=list_keys,
-            # Custom
+            # specific to ICL dataset
             context_key=context_key,
             base_batch=base_batch,
             static_keys=static_keys,
@@ -1268,7 +1250,6 @@ def __init__(
         continuation_delimiter: str,
         destination_path: str,
         prelimiter: str = '',
-        context_key: str = 'context',
         answer_key: str = 'answer',
         strip_dataset: bool = True,
         padding_side: str = 'right',
@@ -1308,7 +1289,7 @@ def __init__(
             hf_parsing_map=hf_parsing_map,
             generation_kwargs=generation_kwargs,
             list_keys=list_keys,
-            # Custom class
+            # specific to ICL dataset
             choices_key=choices_key,
             context_key=choices_key,
             static_keys=static_keys,
@@ -1523,7 +1504,6 @@ def build_icl_dataloader(
     do_normalization: bool = True,
 ) -> DataSpec:
     """Factory method that builds the specific dataset for the specified.
     icl_task_type. See documentation for `get_icl_task_dataloader` for argument
     documentation.
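
Note: `build_icl_dataloader` is the factory behind the public `get_icl_task_dataloader` entry point. A hedged sketch of a call, with argument names taken from the signature fragments visible in this diff and all values purely illustrative:

    from transformers import AutoTokenizer
    from llmfoundry.eval.datasets.in_context_learning_evaluation import (
        get_icl_task_dataloader,
    )

    tokenizer = AutoTokenizer.from_pretrained('gpt2')
    data_spec = get_icl_task_dataloader(
        icl_task_type='multiple_choice',  # assumed task name
        dataset_uri='local_data/piqa.jsonl',  # hypothetical path
        tokenizer=tokenizer,
        batch_size=8,
        max_seq_len=1024,
        pad_tok_id=tokenizer.eos_token_id,
        num_fewshot=0,
        prompt_string='',
        example_delimiter='\n',
        continuation_delimiter=': ',
        destination_path='/tmp/icl_piqa.jsonl',  # hypothetical path
    )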
