Skip to content

Commit

Permalink
Comments, still need to fix lint errors
Browse files Browse the repository at this point in the history
  • Loading branch information
sanjari-orb committed Jun 12, 2024
1 parent 1284a59 commit 1e37fd7
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 162 deletions.
182 changes: 54 additions & 128 deletions llmfoundry/eval/datasets/in_context_learning_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,11 +116,11 @@ def __init__(
max_seq_len: int,
pad_tok_id: int,
num_fewshot: int,
fewshot_random_seed: int,
prompt_string: str,
example_delimiter: str,
continuation_delimiter: str,
destination_path: str,
fewshot_random_seed: int = 1234,
prompt_string: str = '',
example_delimiter: str = '\n',
continuation_delimiter: str = ' ',
prelimiter: str = '',
context_key: str = 'context',
answer_key: str = 'answer',
Expand All @@ -136,8 +136,6 @@ def __init__(
static_keys: Optional[List] = None,
list_keys: Optional[List] = None,
tensor_keys: Optional[List] = None,
*args: Any,
**kwargs: Any,
):
self.tokenizer = tokenizer
self.prefix_space = tokenizer_needs_prefix_space(self.tokenizer)
Expand Down Expand Up @@ -194,6 +192,17 @@ def get_num_samples_in_batch(self, batch: Dict) -> int:
return batch['input_ids'].shape[0]

def get_effective_batch_size(self, batch_size: int) -> int:
r"""Returns effective batch size computed for given ICL task.
The effective batch size may not be equal to the configured evaluation
batch size because for certain ICL tasks, >1 prompts can get created
for every input query depending on the number of choices/continuations.
This requires the effective batch size to be reduced to prevent larger batches than expected during eval. For example,
check InContextLearningMultipleChoiceTaskDataset.
Args:
batch_size (int): Original batch size configured for ICL evaluations
"""
return batch_size

def update_generation_kwargs(self, generation_kwargs: Dict) -> None:
Expand Down Expand Up @@ -562,11 +571,11 @@ def __init__(
max_seq_len: int,
pad_tok_id: int,
num_fewshot: int,
fewshot_random_seed: int,
prompt_string: str,
example_delimiter: str,
continuation_delimiter: str,
destination_path: str,
fewshot_random_seed: int = 1234,
prompt_string: str = '',
example_delimiter: str = '\n',
continuation_delimiter: str = ' ',
prelimiter: str = '',
context_key: str = 'context',
answer_key: str = 'answer',
Expand All @@ -580,8 +589,6 @@ def __init__(
cot_delimiter: str = '',
early_stopping_criteria: Optional[List[str]] = None,
do_normalization: bool = True,
*args: Any,
**kwargs: Any,
):
if tokenizer.eos_token_id is None:
raise ValueError(
Expand Down Expand Up @@ -626,8 +633,6 @@ def __init__(
static_keys=static_keys,
list_keys=list_keys,
tensor_keys=tensor_keys,
*args,
**kwargs,
)
# NOTE: set these after init call because they take class vars
self.early_stopping_criteria = early_stopping_criteria
Expand Down Expand Up @@ -781,9 +786,7 @@ def collate_fn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]:

def split_batch(self, batch: Any,
microbatch_size: Union[int, float]) -> Sequence[Any]:
"""Handling for certain specialty columns that must be split into.
batches in different formats.
"""Split batch handling for special columns.
Args:
batch (Dict): Batch of data
Expand Down Expand Up @@ -820,6 +823,7 @@ def split_batch(self, batch: Any,
for idx in range(num_chunks)]
return batched_list


class InContextLearningLMTaskDataset(InContextLearningDataset):
"""A dataset that constructs batches for in-context learning language.
Expand All @@ -840,11 +844,11 @@ def __init__(
max_seq_len: int,
pad_tok_id: int,
num_fewshot: int,
fewshot_random_seed: int,
prompt_string: str,
example_delimiter: str,
continuation_delimiter: str,
destination_path: str,
fewshot_random_seed: int = 1234,
prompt_string: str = '',
example_delimiter: str = '\n',
continuation_delimiter: str = ' ',
prelimiter: str = '',
context_key: str = 'context',
strip_dataset: bool = True,
Expand All @@ -855,8 +859,6 @@ def __init__(
generation_kwargs: Optional[Dict] = None,
static_keys: Optional[List] = None,
list_keys: Optional[List] = None,
*args: Any,
**kwargs: Any,
):
super().__init__(
dataset_uri=dataset_uri,
Expand Down Expand Up @@ -898,8 +900,6 @@ def __init__(
'labels': 'context',
},
padding_side='right',
*args,
**kwargs,
)


Expand Down Expand Up @@ -936,12 +936,14 @@ def __init__(
max_seq_len: int,
pad_tok_id: int,
num_fewshot: int,
fewshot_random_seed: int,
prompt_string: str,
example_delimiter: str,
continuation_delimiter: str,
destination_path: str,
fewshot_random_seed: int = 1234,
prompt_string: str = '',
example_delimiter: str = '\n',
continuation_delimiter: str = ' ',
prelimiter: str = '',
context_key: str = 'query',
tensor_keys: Optional[List] = None,
answer_key: str = 'answer',
strip_dataset: bool = True,
tokenize_labels: bool = True,
Expand All @@ -956,8 +958,6 @@ def __init__(
list_of_tensors_keys: Optional[List] = None,
list_of_tuples_keys: Optional[List] = None,
list_of_primitives: Optional[List] = None,
*args: Any,
**kwargs: Any,
):
self.choices_key = choices_key
base_batch = {
Expand All @@ -968,12 +968,10 @@ def __init__(
'gold_indices': [],
'choice_groupings': [],
}
context_key = kwargs.pop('context_key', 'query')
static_keys = kwargs.pop('static_keys', ['mode', 'generation_kwargs'])
tensor_keys = kwargs.pop(
'tensor_keys',
['input_ids', 'labels', 'attention_mask'],
)
if not static_keys:
static_keys = ['mode', 'generation_kwargs']
if not tensor_keys:
tensor_keys = ['input_ids', 'labels', 'attention_mask']
self.list_of_tensors_keys = list_of_tensors_keys or [
'continuation_indices',
]
Expand Down Expand Up @@ -1006,8 +1004,6 @@ def __init__(
static_keys=static_keys,
tensor_keys=tensor_keys,
padding_side='right',
*args,
**kwargs,
)
self.num_choices = len(self.dataset[0][self.choices_key])
self.batch_mapping_per_choice = {
Expand Down Expand Up @@ -1244,11 +1240,11 @@ def __init__(
max_seq_len: int,
pad_tok_id: int,
num_fewshot: int,
fewshot_random_seed: int,
prompt_string: str,
example_delimiter: str,
continuation_delimiter: str,
destination_path: str,
fewshot_random_seed: int = 1234,
prompt_string: str = '',
example_delimiter: str = '\n',
continuation_delimiter: str = ' ',
prelimiter: str = '',
answer_key: str = 'answer',
strip_dataset: bool = True,
Expand All @@ -1261,8 +1257,6 @@ def __init__(
generation_kwargs: Optional[Dict] = None,
list_keys: Optional[List] = None,
choices_key: str = 'context_options',
*args: Any,
**kwargs: Any,
):
static_keys = ['mode']
tensor_keys = ['input_ids', 'labels', 'attention_mask']
Expand Down Expand Up @@ -1295,8 +1289,6 @@ def __init__(
static_keys=static_keys,
tensor_keys=tensor_keys,
list_of_tensors_keys=list_of_tensors_keys,
*args,
**kwargs,
)
self.base_batch = {
'input_ids': [],
Expand Down Expand Up @@ -1480,28 +1472,16 @@ def tokenize_example(
tokenized_example['gold'] = example['gold']
return tokenized_example



def build_icl_dataloader(
icl_task_type: str,
dataset_uri: str,
tokenizer: transformers.PreTrainedTokenizerBase,
tokenizer: Union[transformers.PreTrainedTokenizer,
transformers.PreTrainedTokenizerFast],
batch_size: int,
max_seq_len: int,
pad_tok_id: int,
num_fewshot: int,
prompt_string: str, # e.g. 'translate english to french:'
example_delimiter: str, # e.g. '\n'
continuation_delimiter: str, # e.g. ''
hf_loading_vars: Dict,
hf_parsing_map: Dict,
destination_path: str,
prelimiter: str, # e.g. 'Question: '
cot_delimiter: str, # e.g. ' ### '
fewshot_random_seed: int,
generation_kwargs: Dict,
early_stopping_criteria: Optional[List[str]] = None,
do_normalization: bool = True,
destination_path: str = '',
kwargs: Optional[Dict[str, Any]] = None,
) -> DataSpec:
"""Factory method that builds the specific dataset for the specified.
icl_task_type. See documentation for `get_icl_task_dataloader` for argument
Expand All @@ -1513,28 +1493,16 @@ def build_icl_dataloader(
this might be different)
3. set the `split_batch` function if necessary
"""
name = icl_task_type
kwargs: Dict[str, Any] = {
# Add named parameters to kwargs
kwargs.update({
'dataset_uri': dataset_uri,
'tokenizer': tokenizer,
'max_seq_len':max_seq_len,
'pad_tok_id':pad_tok_id,
'num_fewshot': num_fewshot,
'prompt_string': prompt_string,
'example_delimiter': example_delimiter,
'continuation_delimiter': continuation_delimiter,
'destination_path': destination_path,
'prelimiter': prelimiter,
'fewshot_random_seed': fewshot_random_seed,
'hf_loading_vars': hf_loading_vars,
'hf_parsing_map': hf_parsing_map,
'cot_delimiter': cot_delimiter,
'early_stopping_criteria': early_stopping_criteria,
'do_normalization': do_normalization,
'generation_kwargs': generation_kwargs,
}
'destination_path': destination_path,
})
dataset = construct_from_registry(
name=name,
name=icl_task_type,
registry=registry.icl_datasets,
partial_function=False,
pre_validation_function=None,
Expand Down Expand Up @@ -1642,22 +1610,11 @@ def get_icl_task_dataloader(
tokenizer: Union[transformers.PreTrainedTokenizer,
transformers.PreTrainedTokenizerFast],
batch_size: int,
max_seq_len: int,
pad_tok_id: int,
num_fewshot: int,
prompt_string: str, # e.g. 'translate english to french:'
example_delimiter: str, # e.g. '\n'
continuation_delimiter: str = '',
destination_path: str = '',
question_prelimiter: str = '', # e.g. 'Question: '
fewshot_random_seed: int = 1234,
cot_delimiter: str = '',
hf_loading_vars: Dict,
hf_parsing_map: Dict,
has_categories: bool = False,
hf_loading_vars: Optional[Dict] = None,
hf_parsing_map: Optional[Dict] = None,
generation_kwargs: Optional[Dict] = None,
early_stopping_criteria: Optional[List[str]] = None,
do_normalization: bool = True,
destination_path: str = '',
kwargs: Optional[Dict[str, Any]] = None,
) -> Union[DataSpec, Dict[str, DataSpec]]:
r"""Constructs a dataloader (or dataloaders if has_categories is True)
Expand Down Expand Up @@ -1738,15 +1695,6 @@ def get_icl_task_dataloader(
Returns:
DataLoader: A dataloader used for performing in-context learning evaluation on the dataset provided.
"""
if hf_loading_vars is None:
hf_loading_vars = {}
if hf_parsing_map is None:
hf_parsing_map = {}
if generation_kwargs is None:
generation_kwargs = {}
if early_stopping_criteria is None:
early_stopping_criteria = []

if has_categories:
result_dls = {}
output_files = partition_dataset_by_category(
Expand All @@ -1763,21 +1711,10 @@ def get_icl_task_dataloader(
dataset_uri=partition_uri,
tokenizer=tokenizer,
batch_size=batch_size,
max_seq_len=max_seq_len,
pad_tok_id=pad_tok_id,
num_fewshot=num_fewshot,
prompt_string=prompt_string,
example_delimiter=example_delimiter,
continuation_delimiter=continuation_delimiter,
destination_path=partition_uri + '_tmp',
prelimiter=question_prelimiter,
cot_delimiter=cot_delimiter,
fewshot_random_seed=fewshot_random_seed,
hf_loading_vars=hf_loading_vars,
hf_parsing_map=hf_parsing_map,
generation_kwargs=generation_kwargs,
early_stopping_criteria=early_stopping_criteria,
do_normalization=do_normalization,
kwargs=kwargs
)
return result_dls
else:
Expand All @@ -1786,19 +1723,8 @@ def get_icl_task_dataloader(
dataset_uri=dataset_uri,
tokenizer=tokenizer,
batch_size=batch_size,
max_seq_len=max_seq_len,
pad_tok_id=pad_tok_id,
num_fewshot=num_fewshot,
prompt_string=prompt_string,
example_delimiter=example_delimiter,
hf_loading_vars=hf_loading_vars,
hf_parsing_map=hf_parsing_map,
continuation_delimiter=continuation_delimiter,
destination_path=destination_path,
prelimiter=question_prelimiter,
cot_delimiter=cot_delimiter,
fewshot_random_seed=fewshot_random_seed,
generation_kwargs=generation_kwargs,
early_stopping_criteria=early_stopping_criteria,
do_normalization=do_normalization,
kwargs=kwargs
)
7 changes: 4 additions & 3 deletions llmfoundry/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@
from composer.models import ComposerModel
from composer.optim import ComposerScheduler
from torch.optim import Optimizer
from torch.utils.data import DataLoader as TorchDataloader, Dataset
from torch.utils.data import DataLoader as TorchDataloader
from torchmetrics import Metric
from transformers import PreTrainedTokenizerBase

from llmfoundry.eval import InContextLearningDataset
from llmfoundry.interfaces import CallbackWithConfig
from llmfoundry.layers_registry import (
attention_classes,
Expand Down Expand Up @@ -207,12 +208,12 @@
)

_icl_datasets_description = (
'The ICL dataloaders registry is used to register an torch.utils.data.Dataset class which can be used for ICL tasks.'
'The ICL datasets registry is used to register an torch.utils.data.Dataset class which can be used for ICL tasks.'
)
icl_datasets = create_registry(
'llmfoundry',
'icl_datasets',
generic_type=Type[Dataset],
generic_type=Type[InContextLearningDataset],
entry_points=True,
description=_icl_datasets_description,
)
Expand Down
Loading

0 comments on commit 1e37fd7

Please sign in to comment.