Commit

Merge branch 'main' into milo/foundry-type-cleanup

milocress authored Apr 15, 2024
2 parents 7f3d913 + f01f625 commit dd4a926
Showing 65 changed files with 7,188 additions and 441 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -2,6 +2,7 @@
 my-copy-c4*/
 my-copy-arxiv*/
 *.jsonl*
+!tests/eval/local_data/*.jsonl

 # WandB
 wandb/
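Note on the new pattern: a leading '!' re-includes paths that an earlier rule ignored, so the eval fixtures under tests/eval/local_data/ stay tracked even though *.jsonl* excludes all other JSONL artifacts. (Git cannot re-include a file if one of its parent directories is itself ignored, so negations like this must target paths whose directories are not excluded.)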
4 changes: 1 addition & 3 deletions llmfoundry/__init__.py
@@ -28,7 +28,7 @@
     MultiheadAttention, attn_bias_shape, build_alibi_bias, build_attn_bias,
     flash_attn_fn, scaled_multihead_dot_product_attention)
 from llmfoundry.models.layers.blocks import MPTBlock
-from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY, MPTMLP, build_ffn
+from llmfoundry.models.layers.ffn import MPTMLP
 from llmfoundry.models.mpt import (ComposerMPTCausalLM, MPTConfig,
                                    MPTForCausalLM, MPTModel, MPTPreTrainedModel)
 from llmfoundry.tokenizers import TiktokenTokenizerWrapper
@@ -37,9 +37,7 @@
     'build_finetuning_dataloader',
     'Seq2SeqFinetuningCollator',
     'MPTBlock',
-    'FFN_CLASS_REGISTRY',
     'MPTMLP',
-    'build_ffn',
     'MPTConfig',
     'MPTPreTrainedModel',
     'MPTModel',
43 changes: 31 additions & 12 deletions llmfoundry/data/finetuning/tasks.py
@@ -77,7 +77,7 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]:
 _ALLOWED_MESSAGES_KEYS = {'messages'}
 _ALLOWED_ROLE_KEYS = {'role'}
 _ALLOWED_CONTENT_KEYS = {'content'}
-_ALLOWED_ROLES = {'user', 'assistant', 'system'}
+_ALLOWED_ROLES = {'user', 'assistant', 'system', 'tool'}
 _ALLOWED_LAST_MESSAGE_ROLES = {'assistant'}
 DOWNLOADED_FT_DATASETS_DIRPATH = os.path.abspath(
     os.path.join(os.path.realpath(__file__), os.pardir, os.pardir, os.pardir,
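With 'tool' added to _ALLOWED_ROLES, chat-formatted samples may now include tool turns, while _ALLOWED_LAST_MESSAGE_ROLES still requires the final turn to come from the assistant. A minimal sketch of a sample that should now pass validation, assuming the 'messages'/'role'/'content' schema named by the constants above (the values are illustrative):

    example = {
        'messages': [
            {'role': 'user', 'content': 'What is 2 + 2?'},
            {'role': 'tool', 'content': '{"result": 4}'},  # newly permitted role
            {'role': 'assistant', 'content': '2 + 2 = 4.'},  # last turn must be assistant
        ]
    }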
@@ -217,7 +217,7 @@ def slice_out_last_turn(
         if conversation_through_previous_turn != prompt_with_history[:len(
                 conversation_through_previous_turn)]:
             raise ValueError(
-                f'The prompt_with_histry must start with the conversation through the previous turn. {conversation_through_previous_turn=}, {prompt_with_history=}'
+                f'The prompt_with_history must start with the conversation through the previous turn. {conversation_through_previous_turn=}, {prompt_with_history=}'
             )
         if prompt_with_history != full_conversation[:len(prompt_with_history)]:
             raise ValueError(
@@ -624,9 +624,8 @@ def print_registered_tasks(self) -> None:
         log.info('\n'.join(tasks))

     def get_preprocessing_fn_from_dict(
-            self,
-            mapping: Dict[str,
-                          str]) -> Callable[[Dict[str, Any]], Dict[str, str]]:
+        self, mapping: Dict[str,
+                            str]) -> Callable[[Dict[str, Any]], Example]:
         """Get a preprocessing function from a dictionary.

         The dictionary maps column names in the dataset to "prompt" and "response".
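As a sketch of the dict contract (the column names 'question' and 'answer' are hypothetical, and we assume the mapping's keys are the canonical 'prompt' and 'response'):

    fn = dataset_constructor.get_preprocessing_fn_from_dict(
        {'prompt': 'question', 'response': 'answer'})
    fn({'question': 'What is MPT?', 'answer': 'A decoder-only transformer.'})
    # -> {'prompt': 'What is MPT?', 'response': 'A decoder-only transformer.'}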
@@ -662,7 +661,7 @@ def get_preprocessing_fn_from_str(
         self,
         preprocessor: Optional[str],
         dataset_name: Optional[str] = None
-    ) -> Optional[Callable[[Dict[str, Any]], Dict[str, str]]]:
+    ) -> Optional[Callable[[Dict[str, Any]], Example]]:
         """Get a preprocessing function from a string.

         String can be either a registered function or an import path.
@@ -710,7 +709,7 @@ def get_preprocessing_fn_from_str(

     def build_from_hf(
         self, dataset_name: str, split: str, safe_load: bool, max_seq_len: int,
-        preprocessing_fn: Optional[Callable[[dict[str, Any]], dict[str, str]]],
+        preprocessing_fn: Optional[Callable[[dict[str, Any]], Example]],
         tokenizer: PreTrainedTokenizerBase, target_prompts: str,
         target_responses: str, decoder_only_format: bool, hf_kwargs: Dict[str,
                                                                           Any]
@@ -793,7 +792,8 @@ def build_from_hf(

         def dataset_mapper(example: Dict):
             if preprocessing_fn is not None:
-                example = preprocessing_fn(example)
+                return tokenize_formatted_example(preprocessing_fn(example),
+                                                  tokenizer)
             return tokenize_formatted_example(example, tokenizer)

         detected_cpu_count = os.cpu_count() or 1
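Functionally the two branches are unchanged: a preprocessed row still ends up in tokenize_formatted_example. The rewrite just avoids rebinding example, so that name keeps a single type (the raw row) and the preprocessor's Example output flows straight into tokenization, consistent with the tightened annotations above.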
@@ -857,7 +857,7 @@ def build_from_streaming(self, *args: Any,


 @dataset_constructor.register('tatsu-lab/alpaca')
-def alpaca_preprocessing_function(inp: Dict) -> Dict[str, str]:
+def alpaca_preprocessing_function(inp: Dict) -> PromptResponseDict:
     """Split out prompt/response from text."""
     try:
         prompt, response = inp['text'].split('### Response:')
@@ -869,7 +869,7 @@ def alpaca_preprocessing_function(inp: Dict) -> Dict[str, str]:


 @dataset_constructor.register('HuggingFaceH4/databricks_dolly_15k')
-def dolly_preprocessing_function(inp: Dict) -> Dict[str, str]:
+def dolly_preprocessing_function(inp: Dict) -> PromptResponseDict:
     """Format the text string."""
     PROMPT_FORMAT = 'Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:\n'
     try:
@@ -885,7 +885,7 @@ def dolly_preprocessing_function(inp: Dict) -> Dict[str, str]:


 @dataset_constructor.register('bigscience/P3')
-def p3_preprocessing_function(inp: Dict) -> Dict[str, str]:
+def p3_preprocessing_function(inp: Dict) -> PromptResponseDict:
     """Format the already-split example."""
     return {
         'prompt': inp['inputs'] + ':',
@@ -895,7 +895,7 @@ def p3_preprocessing_function(inp: Dict) -> Dict[str, str]:

 # Muennighoff's P3 and flan datasets share a similar convention
 @dataset_constructor.register('Muennighoff/P3', 'Muennighoff/flan')
-def muennighoff_tokenize_function(inp: Dict) -> Dict[str, str]:
+def muennighoff_tokenize_function(inp: Dict) -> PromptResponseDict:
     """Format the already-split example."""
     try:
         prompt: str = inp['inputs']
@@ -908,3 +908,22 @@ def muennighoff_tokenize_function(inp: Dict) -> Dict[str, str]:
     except Exception as e:
         raise UnableToProcessPromptResponseError(inp) from e
     return {'prompt': prompt, 'response': response}
+
+
+@dataset_constructor.register('teknium/OpenHermes-2.5')
+def shareGPT_format_preprocessor(inp: Dict) -> ChatFormattedDict:
+    """Convert from ShareGPT format to our chat format."""
+    role_map = {
+        'human': 'user',
+        'gpt': 'assistant',
+    }
+    try:
+        conversation = inp['conversations']
+        messages: List[Dict[str, str]] = []
+        for message in conversation:
+            role: str = role_map.get(message['from'], message['from'])
+            content: str = message['value']
+            messages.append({'role': role, 'content': content})
+    except Exception as e:
+        raise UnableToProcessPromptResponseError(inp) from e
+    return {'messages': messages}
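A quick demonstration of what the new preprocessor yields, using a made-up ShareGPT-style row (the conversations/from/value keys follow the function above; roles absent from role_map, such as 'system', pass through unchanged via the .get default):

    row = {
        'conversations': [
            {'from': 'human', 'value': 'Name a prime number.'},
            {'from': 'gpt', 'value': '7 is prime.'},
        ]
    }
    shareGPT_format_preprocessor(row)
    # -> {'messages': [{'role': 'user', 'content': 'Name a prime number.'},
    #                  {'role': 'assistant', 'content': '7 is prime.'}]}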
2 changes: 2 additions & 0 deletions llmfoundry/eval/__init__.py
@@ -0,0 +1,2 @@
+# Copyright 2024 MosaicML LLM Foundry authors
+# SPDX-License-Identifier: Apache-2.0
31 changes: 31 additions & 0 deletions llmfoundry/eval/datasets/__init__.py
@@ -0,0 +1,31 @@
+# Copyright 2024 MosaicML LLM Foundry authors
+# SPDX-License-Identifier: Apache-2.0
+
+"""Natively supported in-context learning evaluation datasets."""
+
+from llmfoundry.eval.datasets.in_context_learning_evaluation import (
+    InContextLearningCodeEvalDataset, InContextLearningDataset,
+    InContextLearningGenerationTaskWithAnswersDataset,
+    InContextLearningLMTaskDataset, InContextLearningMultipleChoiceTaskDataset,
+    InContextLearningSchemaTaskDataset, get_icl_task_dataloader)
+from llmfoundry.eval.datasets.utils import (get_continuation_span,
+                                            get_fewshot_sample_idxs,
+                                            make_padded_input, strip_data,
+                                            tokenizer_needs_prefix_space,
+                                            trim_context)
+
+__all__ = [
+    'InContextLearningDataset',
+    'InContextLearningGenerationTaskWithAnswersDataset',
+    'InContextLearningLMTaskDataset',
+    'InContextLearningCodeEvalDataset',
+    'InContextLearningMultipleChoiceTaskDataset',
+    'InContextLearningSchemaTaskDataset',
+    'get_icl_task_dataloader',
+    'strip_data',
+    'tokenizer_needs_prefix_space',
+    'trim_context',
+    'get_continuation_span',
+    'get_fewshot_sample_idxs',
+    'make_padded_input',
+]
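With this subpackage in place, eval code can pull the ICL datasets and dataloader builder from llmfoundry directly. A minimal sketch of the import surface (only names exported in __all__ above; the full get_icl_task_dataloader signature lives in in_context_learning_evaluation.py):

    from llmfoundry.eval.datasets import (InContextLearningLMTaskDataset,
                                          get_icl_task_dataloader)
    # get_icl_task_dataloader constructs a dataloader for a single ICL task
    # type; see in_context_learning_evaluation.py for its parameters.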
(Diffs for the remaining changed files are not shown.)
