diff --git a/llmfoundry/eval/datasets/__init__.py b/llmfoundry/eval/datasets/__init__.py
index 02a2b88b21..c15094a6e0 100644
--- a/llmfoundry/eval/datasets/__init__.py
+++ b/llmfoundry/eval/datasets/__init__.py
@@ -22,6 +22,12 @@
     tokenizer_needs_prefix_space,
     trim_context,
 )
+from llmfoundry.registry import icl_dataloaders
+
+icl_dataloaders.register('multiple_choice', func=InContextLearningMultipleChoiceTaskDataset)
+icl_dataloaders.register('schema', func=InContextLearningSchemaTaskDataset)
+icl_dataloaders.register('language_modeling', func=InContextLearningLMTaskDataset)
+icl_dataloaders.register('generation_task_with_answers', func=InContextLearningGenerationTaskWithAnswersDataset)
 
 __all__ = [
     'InContextLearningDataset',
diff --git a/llmfoundry/eval/datasets/in_context_learning_evaluation.py b/llmfoundry/eval/datasets/in_context_learning_evaluation.py
index debb0dbc6f..75dc513e1e 100644
--- a/llmfoundry/eval/datasets/in_context_learning_evaluation.py
+++ b/llmfoundry/eval/datasets/in_context_learning_evaluation.py
@@ -29,6 +29,8 @@
     tokenizer_needs_prefix_space,
     trim_context,
 )
+from llmfoundry.utils.registry_utils import construct_from_registry
+from llmfoundry import registry
 
 log = logging.getLogger(__name__)
 
@@ -134,6 +136,10 @@ def __init__(
         static_keys: Optional[List] = None,
         list_keys: Optional[List] = None,
         tensor_keys: Optional[List] = None,
+        # Accept and ignore extra arguments so that every registered ICL
+        # dataset class can be constructed from one shared kwargs dict.
+        *args: Any,
+        **kwargs: Any,
     ):
         self.tokenizer = tokenizer
         self.prefix_space = tokenizer_needs_prefix_space(self.tokenizer)
@@ -1323,86 +1327,41 @@ def build_icl_dataloader(
         this might be different)
     3. set the `split_batch` function if necessary
     """
-    if icl_task_type == 'multiple_choice':
-        dataset = InContextLearningMultipleChoiceTaskDataset(
-            dataset_uri=dataset_uri,
-            tokenizer=tokenizer,
-            max_seq_len=max_seq_len,
-            pad_tok_id=pad_tok_id,
-            num_fewshot=num_fewshot,
-            prompt_string=prompt_string,
-            example_delimiter=example_delimiter,
-            continuation_delimiter=continuation_delimiter,
-            destination_path=destination_path,
-            prelimiter=prelimiter,
-            fewshot_random_seed=fewshot_random_seed,
-            hf_loading_vars=hf_loading_vars,
-            hf_parsing_map=hf_parsing_map,
-            generation_kwargs=generation_kwargs,
-        )
-        batch_size = max(dataset.num_choices, batch_size)
-        effective_batchsize = batch_size // dataset.num_choices
-    elif icl_task_type == 'schema':
-        dataset = InContextLearningSchemaTaskDataset(
-            dataset_uri=dataset_uri,
-            tokenizer=tokenizer,
-            max_seq_len=max_seq_len,
-            pad_tok_id=pad_tok_id,
-            num_fewshot=num_fewshot,
-            prompt_string=prompt_string,
-            example_delimiter=example_delimiter,
-            continuation_delimiter=continuation_delimiter,
-            destination_path=destination_path,
-            prelimiter=prelimiter,
-            fewshot_random_seed=fewshot_random_seed,
-            hf_loading_vars=hf_loading_vars,
-            hf_parsing_map=hf_parsing_map,
-            generation_kwargs=generation_kwargs,
-        )
+    name = icl_task_type
+    kwargs: Dict[str, Any] = {
+        'dataset_uri': dataset_uri,
+        'tokenizer': tokenizer,
+        'max_seq_len': max_seq_len,
+        'pad_tok_id': pad_tok_id,
+        'num_fewshot': num_fewshot,
+        'prompt_string': prompt_string,
+        'example_delimiter': example_delimiter,
+        'continuation_delimiter': continuation_delimiter,
+        'destination_path': destination_path,
+        'prelimiter': prelimiter,
+        'fewshot_random_seed': fewshot_random_seed,
+        'hf_loading_vars': hf_loading_vars,
+        'hf_parsing_map': hf_parsing_map,
+        'cot_delimiter': cot_delimiter,
+        'early_stopping_criteria': early_stopping_criteria,
+        'do_normalization': do_normalization,
+        'generation_kwargs': generation_kwargs,
+    }
+    dataset = construct_from_registry(
+        name=name,
+        registry=registry.icl_dataloaders,
+        partial_function=False,
+        pre_validation_function=None,
+        post_validation_function=None,
+        kwargs=kwargs,
+    )
+    # Multiple-choice style datasets expand each example into one row per
+    # answer choice, so fold `num_choices` into the effective batch size.
+    if hasattr(dataset, 'num_choices'):
         batch_size = max(dataset.num_choices, batch_size)
         effective_batchsize = batch_size // dataset.num_choices
-    elif icl_task_type == 'language_modeling':
-        dataset = InContextLearningLMTaskDataset(
-            dataset_uri=dataset_uri,
-            tokenizer=tokenizer,
-            max_seq_len=max_seq_len,
-            pad_tok_id=pad_tok_id,
-            num_fewshot=num_fewshot,
-            prompt_string=prompt_string,
-            example_delimiter=example_delimiter,
-            continuation_delimiter=continuation_delimiter,
-            destination_path=destination_path,
-            prelimiter=prelimiter,
-            fewshot_random_seed=fewshot_random_seed,
-            hf_loading_vars=hf_loading_vars,
-            hf_parsing_map=hf_parsing_map,
-            generation_kwargs=generation_kwargs,
-        )
-        effective_batchsize = batch_size
-    elif icl_task_type == 'generation_task_with_answers':
-        dataset = InContextLearningGenerationTaskWithAnswersDataset(
-            dataset_uri=dataset_uri,
-            tokenizer=tokenizer,
-            max_seq_len=max_seq_len,
-            pad_tok_id=pad_tok_id,
-            num_fewshot=num_fewshot,
-            prompt_string=prompt_string,
-            example_delimiter=example_delimiter,
-            continuation_delimiter=continuation_delimiter,
-            destination_path=destination_path,
-            prelimiter=prelimiter,
-            fewshot_random_seed=fewshot_random_seed,
-            hf_loading_vars=hf_loading_vars,
-            hf_parsing_map=hf_parsing_map,
-            cot_delimiter=cot_delimiter,
-            early_stopping_criteria=early_stopping_criteria,
-            do_normalization=do_normalization,
-            generation_kwargs=generation_kwargs,
-        )
-        effective_batchsize = batch_size
     else:
-        raise Exception(f'Unrecognized ICL task type: {icl_task_type}')
-
+        effective_batchsize = batch_size
     sampler = dist.get_sampler(dataset, drop_last=False, shuffle=False)
 
     split_batch = None
diff --git a/llmfoundry/registry.py b/llmfoundry/registry.py
index 0c8e64b759..5910e449b3 100644
--- a/llmfoundry/registry.py
+++ b/llmfoundry/registry.py
@@ -7,7 +7,7 @@
 from composer.models import ComposerModel
 from composer.optim import ComposerScheduler
 from torch.optim import Optimizer
-from torch.utils.data import DataLoader as TorchDataloader
+from torch.utils.data import DataLoader as TorchDataloader, Dataset
 from torchmetrics import Metric
 from transformers import PreTrainedTokenizerBase
 
@@ -206,6 +206,17 @@
     description=_metrics_description,
 )
 
+_icl_dataloaders_description = (
+    'The ICL dataloaders registry is used to register a torch.utils.data.Dataset class that builds the dataset for a given ICL task type.'
+)
+icl_dataloaders = create_registry(
+    'llmfoundry',
+    'icl_dataloaders',
+    generic_type=Type[Dataset],
+    entry_points=True,
+    description=_icl_dataloaders_description,
+)
+
 __all__ = [
     'loggers',
     'callbacks',
@@ -228,4 +239,5 @@
     'attention_classes',
     'attention_implementations',
     'fcs',
+    'icl_dataloaders',
 ]