Skip to content

Commit

Permalink
x
Browse files Browse the repository at this point in the history
  • Loading branch information
sanjari-orb committed Jun 5, 2024
1 parent ac56dc5 commit 3a948a1
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 78 deletions.
6 changes: 6 additions & 0 deletions llmfoundry/eval/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,12 @@
tokenizer_needs_prefix_space,
trim_context,
)
from llmfoundry.registry import icl_dataloaders

# Register the built-in ICL task dataset classes under the task-type names
# that build_icl_dataloader looks up (e.g. icl_task_type='multiple_choice').
_ICL_TASK_DATASETS = {
    'multiple_choice': InContextLearningMultipleChoiceTaskDataset,
    'schema': InContextLearningSchemaTaskDataset,
    'language_modeling': InContextLearningLMTaskDataset,
    'generation_task_with_answers': InContextLearningGenerationTaskWithAnswersDataset,
}
for _task_name, _dataset_cls in _ICL_TASK_DATASETS.items():
    icl_dataloaders.register(_task_name, func=_dataset_cls)

__all__ = [
'InContextLearningDataset',
Expand Down
115 changes: 38 additions & 77 deletions llmfoundry/eval/datasets/in_context_learning_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
tokenizer_needs_prefix_space,
trim_context,
)
from llmfoundry.utils.registry_utils import construct_from_registry
from llmfoundry import registry

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -134,6 +136,8 @@ def __init__(
static_keys: Optional[List] = None,
list_keys: Optional[List] = None,
tensor_keys: Optional[List] = None,
*args: Any,
**kwargs: Any,
):
self.tokenizer = tokenizer
self.prefix_space = tokenizer_needs_prefix_space(self.tokenizer)
Expand Down Expand Up @@ -1323,86 +1327,43 @@ def build_icl_dataloader(
this might be different)
3. set the `split_batch` function if necessary
"""
if icl_task_type == 'multiple_choice':
dataset = InContextLearningMultipleChoiceTaskDataset(
dataset_uri=dataset_uri,
tokenizer=tokenizer,
max_seq_len=max_seq_len,
pad_tok_id=pad_tok_id,
num_fewshot=num_fewshot,
prompt_string=prompt_string,
example_delimiter=example_delimiter,
continuation_delimiter=continuation_delimiter,
destination_path=destination_path,
prelimiter=prelimiter,
fewshot_random_seed=fewshot_random_seed,
hf_loading_vars=hf_loading_vars,
hf_parsing_map=hf_parsing_map,
generation_kwargs=generation_kwargs,
)
batch_size = max(dataset.num_choices, batch_size)
effective_batchsize = batch_size // dataset.num_choices
elif icl_task_type == 'schema':
dataset = InContextLearningSchemaTaskDataset(
dataset_uri=dataset_uri,
tokenizer=tokenizer,
max_seq_len=max_seq_len,
pad_tok_id=pad_tok_id,
num_fewshot=num_fewshot,
prompt_string=prompt_string,
example_delimiter=example_delimiter,
continuation_delimiter=continuation_delimiter,
destination_path=destination_path,
prelimiter=prelimiter,
fewshot_random_seed=fewshot_random_seed,
hf_loading_vars=hf_loading_vars,
hf_parsing_map=hf_parsing_map,
generation_kwargs=generation_kwargs,
)
# Build the dataset for the requested ICL task type by looking up its class
# in the icl_dataloaders registry. All registered task datasets share the
# common constructor surface below; task-specific arguments (cot_delimiter,
# early_stopping_criteria, do_normalization, generation_kwargs) are accepted
# by the base class via **kwargs, so they are passed uniformly rather than
# per-branch as before.
dataset_kwargs: Dict[str, Any] = {
    'dataset_uri': dataset_uri,
    'tokenizer': tokenizer,
    'max_seq_len': max_seq_len,
    'pad_tok_id': pad_tok_id,
    'num_fewshot': num_fewshot,
    'prompt_string': prompt_string,
    'example_delimiter': example_delimiter,
    'continuation_delimiter': continuation_delimiter,
    'destination_path': destination_path,
    'prelimiter': prelimiter,
    'fewshot_random_seed': fewshot_random_seed,
    'hf_loading_vars': hf_loading_vars,
    'hf_parsing_map': hf_parsing_map,
    'cot_delimiter': cot_delimiter,
    'early_stopping_criteria': early_stopping_criteria,
    'do_normalization': do_normalization,
    'generation_kwargs': generation_kwargs,
}
# Raises if icl_task_type is not a registered name, replacing the old
# explicit `raise Exception(f'Unrecognized ICL task type: ...')` branch.
dataset = construct_from_registry(
    name=icl_task_type,
    registry=registry.icl_dataloaders,
    partial_function=False,
    pre_validation_function=None,
    post_validation_function=None,
    kwargs=dataset_kwargs,
)

# Multiple-choice style datasets (those exposing `num_choices`) expand each
# example into num_choices rows, so the effective batch size seen by the
# trainer must be divided back down. For all other task types the batch
# size passes through unchanged. NOTE: this must be an if/else — an
# unconditional `effective_batchsize = batch_size` afterwards would clobber
# the divided value for multiple-choice tasks.
if hasattr(dataset, 'num_choices'):
    batch_size = max(dataset.num_choices, batch_size)
    effective_batchsize = batch_size // dataset.num_choices
else:
    effective_batchsize = batch_size

sampler = dist.get_sampler(dataset, drop_last=False, shuffle=False)

split_batch = None
Expand Down
14 changes: 13 additions & 1 deletion llmfoundry/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from composer.models import ComposerModel
from composer.optim import ComposerScheduler
from torch.optim import Optimizer
from torch.utils.data import DataLoader as TorchDataloader
from torch.utils.data import DataLoader as TorchDataloader, Dataset
from torchmetrics import Metric
from transformers import PreTrainedTokenizerBase

Expand Down Expand Up @@ -206,6 +206,17 @@
description=_metrics_description,
)

# Registry mapping ICL task-type names (e.g. 'multiple_choice') to the
# torch.utils.data.Dataset subclass implementing that task. Entries are
# discoverable via entry points, matching the other llmfoundry registries.
_icl_dataloaders_description = (
    'The ICL dataloaders registry is used to register a '
    'torch.utils.data.Dataset class which can be used for ICL tasks.'
)
icl_dataloaders = create_registry(
    'llmfoundry',
    'icl_dataloaders',
    generic_type=Type[Dataset],
    entry_points=True,
    description=_icl_dataloaders_description,
)

__all__ = [
'loggers',
'callbacks',
Expand All @@ -228,4 +239,5 @@
'attention_classes',
'attention_implementations',
'fcs',
'icl_dataloaders',
]

0 comments on commit 3a948a1

Please sign in to comment.