Skip to content

Commit

Permalink
Custom Exceptions for Mosaic Logger (#1014)
Browse files Browse the repository at this point in the history
* flagged areas to throw ft errors + added custom exception

* added exceptions for all files

* fixed type of invalid type errors

* fixed merge

* added exceptions for all files

* fixed merge

* added a catch for missing hf url

* expand dataloader catch to all error types

* added tests for custom exceptions

* fixed a couple tests

* fixed some more tests

* addressed comments

* fixed formatting, updated split correction, moved logger setup into helper

* added check for repeating roles

* more strict checking for prompt response type

* removed some cases we don't need

* updated name of unknown conversation type error

* formatting

* Fix multi model eval (#1055)

* resolved merge conflict

* formatted

* added back two imports

* formatting changes

* formatted again

* sorted imports

* disable yapf for exceptions import

* disabled yapf on test dataloader

---------

Co-authored-by: Daniel King <[email protected]>
  • Loading branch information
jjanezhang and dakinggg authored Mar 26, 2024
1 parent e590acf commit 0ef7cd6
Show file tree
Hide file tree
Showing 13 changed files with 673 additions and 141 deletions.
27 changes: 11 additions & 16 deletions llmfoundry/data/finetuning/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
dataset_constructor)
from llmfoundry.data.packing import BinPackCollator, auto_packing_ratio
from llmfoundry.data.text_data import build_streams, get_tokens_per_batch_func
from llmfoundry.utils.exceptions import (MissingHuggingFaceURLSplitError,
NotEnoughDatasetSamplesError)

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -174,15 +176,12 @@ def build_finetuning_dataloader(cfg: DictConfig,
# Build HF dataloader
dataset_name_or_path = cfg.dataset.hf_name
split = cfg.dataset.get('split')
if split is None:
raise MissingHuggingFaceURLSplitError()

# If dataset is a remote path, download it first.
backend, _, _ = parse_uri(dataset_name_or_path)
if backend not in ['', None]:
if split is None:
raise ValueError(
'When using a HuggingFace dataset from a URL, you must set the ' + \
'`split` key in the dataset config.'
)
dataset_name_or_path = _download_remote_hf_dataset(
remote_path=dataset_name_or_path, split=split)
split = split.replace('-', '_')
Expand Down Expand Up @@ -218,17 +217,13 @@ def build_finetuning_dataloader(cfg: DictConfig,
if hasattr(dataset, '__len__'):
full_dataset_size = len(dataset)
if full_dataset_size < minimum_dataset_size:
raise ValueError(
f'Your dataset (name={cfg.dataset.hf_name}, split={split}) '
+
f'has {full_dataset_size} samples, but your minimum batch size '
+
f'is {minimum_dataset_size} because you are running on {world_size} gpus and '
+
f'your per device batch size is {dataloader_batch_size}. Please increase the number '
+
f'of samples in your dataset to at least {minimum_dataset_size}.'
)
raise NotEnoughDatasetSamplesError(
dataset_name=cfg.dataset.hf_name,
split=split,
dataloader_batch_size=dataloader_batch_size,
world_size=world_size,
full_dataset_size=full_dataset_size,
minimum_dataset_size=minimum_dataset_size)
# Initialize sampler.
sampler = dist.get_sampler(dataset,
drop_last=cfg.drop_last,
Expand Down
107 changes: 52 additions & 55 deletions llmfoundry/data/finetuning/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]:
from collections.abc import Mapping
from functools import partial
from pathlib import Path
from typing import (Any, Callable, Dict, List, Literal, Optional, Sequence, Set,
from typing import (Any, Callable, Dict, List, Literal, Optional, Sequence,
Tuple, Union, cast)

import datasets as hf_datasets
Expand All @@ -51,6 +51,21 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]:
from llmfoundry.data.finetuning.collator import (_HF_IGNORE_INDEX,
stitch_turns_decoder_only,
stitch_turns_encoder_decoder)
# yapf: disable
from llmfoundry.utils.exceptions import (ConsecutiveRepeatedChatRolesError,
IncorrectMessageKeyQuantityError,
InvalidContentTypeError,
InvalidFileExtensionError,
InvalidLastChatMessageRoleError,
InvalidPromptResponseKeysError,
InvalidPromptTypeError,
InvalidResponseTypeError,
InvalidRoleError,
NotEnoughChatDataError,
TooManyKeysInExampleError,
UnableToProcessPromptResponseError,
UnknownExampleTypeError)
# yapf: enable
from llmfoundry.utils.logging_utils import SpecificWarningFilter

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -94,13 +109,11 @@ def _get_example_type(example: Example) -> ExampleType:
if any(allowed_message_key in example
for allowed_message_key in _ALLOWED_MESSAGES_KEYS):
return 'chat'
elif any([
pr in example
for pr in _ALLOWED_PROMPT_KEYS.union(_ALLOWED_RESPONSE_KEYS)
]):
elif any(p in example for p in _ALLOWED_PROMPT_KEYS) and any(
r in example for r in _ALLOWED_RESPONSE_KEYS):
return 'prompt_response'
else:
raise KeyError(f'Unknown conversation type {example=}')
raise UnknownExampleTypeError(example)


def _is_empty_or_nonexistent(dirpath: str) -> bool:
Expand All @@ -115,15 +128,14 @@ def _is_empty_or_nonexistent(dirpath: str) -> bool:
return not os.path.isdir(dirpath) or len(os.listdir(dirpath)) == 0


def _get_key(dictionary: Mapping[str, Any], allowed_keys: Set[str]):
def _get_key(dictionary: Mapping[str, Any], allowed_keys: set[str]):
if not isinstance(dictionary, Mapping):
raise TypeError(
f'Expected dictionary to be a mapping, but found {type(dictionary)}'
)
desired_keys = allowed_keys.intersection(dictionary.keys())
if len(desired_keys) != 1:
raise ValueError(
f'Dictionary has multiple keys in `allowed_keys`: {desired_keys}')
raise TooManyKeysInExampleError(allowed_keys, desired_keys)
return list(desired_keys)[0]


Expand All @@ -136,26 +148,29 @@ def _validate_chat_formatted_example(example: ChatFormattedDict):
raise TypeError(
f'Expected messages to be an iterable, but found {type(messages)}')
if len(messages) <= 1:
raise ValueError('Chat example must have at least two messages')
raise NotEnoughChatDataError()

last_message = messages[-1]
role_key = _get_key(last_message, _ALLOWED_ROLE_KEYS)
last_role = last_message[role_key]
if last_role not in _ALLOWED_LAST_MESSAGE_ROLES:
raise ValueError(f'Invalid last message role: {last_role}')
raise InvalidLastChatMessageRoleError(last_role,
_ALLOWED_LAST_MESSAGE_ROLES)

last_message_role = None
for message in messages:
role_key, content_key = _get_key(message, _ALLOWED_ROLE_KEYS), _get_key(
message, _ALLOWED_CONTENT_KEYS)
if len(message.keys()) != 2:
raise ValueError(
f'Expected 2 keys in message, but found {len(message.keys())}')
raise IncorrectMessageKeyQuantityError(list(message.keys()))
if message[role_key] not in _ALLOWED_ROLES:
raise ValueError(f'Invalid role: {message[role_key]}')
raise InvalidRoleError(message[role_key], _ALLOWED_ROLES)
if not isinstance(message[content_key], str):
raise TypeError(
f'Expected content to be a string, but found {type(message[content_key])}'
)
raise InvalidContentTypeError(type(message[content_key]))
if last_message_role is not None and last_message_role == message[
role_key]:
raise ConsecutiveRepeatedChatRolesError(last_message_role)
last_message_role = message[role_key]


def _slice_chat_formatted_example(
Expand All @@ -182,8 +197,8 @@ def _slice_chat_formatted_example(

last_message = messages[-1]
if last_message['role'] != 'assistant':
raise ValueError(
f'last message must be from assistant. {last_message=}')
raise InvalidLastChatMessageRoleError(last_message['role'],
set(['assistant']))

def slice_out_last_turn(
messages_through_current_turn: List[Dict[str, str]],
Expand Down Expand Up @@ -291,31 +306,20 @@ def _tokenize_prompt_response_formatted_example(
response_keys = example_keys.intersection(_ALLOWED_RESPONSE_KEYS)

if len(prompt_keys) != 1:
raise KeyError(
f'Unable to tokenize example because {len(prompt_keys)} of the allowed prompt keys ' +\
f'were present in {example_keys=}. Please specify exactly one. {_ALLOWED_PROMPT_KEYS=}'
)
raise TooManyKeysInExampleError(_ALLOWED_PROMPT_KEYS, prompt_keys)

if len(response_keys) != 1:
raise KeyError(
f'Unable to tokenize example because {len(response_keys)} of the allowed response keys ' +\
f'were present in {example_keys=}. Please specify exactly one. {_ALLOWED_RESPONSE_KEYS=}'
)
raise TooManyKeysInExampleError(_ALLOWED_RESPONSE_KEYS, response_keys)

prompt_key = prompt_keys.pop()
response_key = response_keys.pop()
prompt = example[prompt_key]
response = example[response_key]

if not isinstance(prompt, str):
raise TypeError(
f'Unable to tokenize example because {prompt_key} was not a string. {example=}'
)
raise InvalidPromptTypeError(type(prompt))

if not isinstance(response, str):
raise TypeError(
f'Unable to tokenize example because {response_key} was not a string. {example=}'
)
raise InvalidResponseTypeError(type(response))

# Note: We default to the tokenizer's add_bos_token and add_eos_token behavior here
# (which we do not do for chat-formatted examples). This is because chat examples specifically
Expand Down Expand Up @@ -360,7 +364,7 @@ def tokenize_formatted_example(
return _tokenize_prompt_response_formatted_example(
prompt_response_example, tokenizer)
else:
raise ValueError(f'Unknown conversation type {example_format=}')
raise UnknownExampleTypeError(example)


def is_valid_ift_example(max_seq_len: int, target_prompts: str,
Expand Down Expand Up @@ -428,7 +432,7 @@ def _stream_remote_local_validate(remote: Optional[str], local: Optional[str],
contents = set(os.listdir(local))
if split is not None and split not in contents:
raise ValueError(
f'local directory {local} does not contain split {split}')
f'Local directory {local} does not contain split {split}')


class StreamingFinetuningDataset(StreamingDataset):
Expand Down Expand Up @@ -636,9 +640,7 @@ def get_preprocessing_fn_from_dict(

def _preprocessor(example: Dict[str, Any]) -> Dict[str, str]:
if list(mapping.keys()) != ['prompt', 'response']:
raise ValueError(
f'Expected {mapping=} to have keys "prompt" and "response".'
)
raise InvalidPromptResponseKeysError(mapping, example)
return {
'prompt': example[mapping['prompt']],
'response': example[mapping['response']]
Expand Down Expand Up @@ -697,9 +699,8 @@ def get_preprocessing_fn_from_str(
return preprocessing_fn

def build_from_hf(
self, dataset_name: str, split: Optional[str], safe_load: bool,
max_seq_len: int, preprocessing_fn: Optional[Callable[[dict[str, Any]],
dict[str, str]]],
self, dataset_name: str, split: str, safe_load: bool, max_seq_len: int,
preprocessing_fn: Optional[Callable[[dict[str, Any]], dict[str, str]]],
tokenizer: PreTrainedTokenizerBase, target_prompts: str,
target_responses: str, decoder_only_format: bool, hf_kwargs: Dict[str,
Any]
Expand Down Expand Up @@ -758,9 +759,8 @@ def build_from_hf(
local_dir_use_symlinks=False,
local_dir=local_dataset_dir)
if _is_empty_or_nonexistent(dirpath=local_dataset_dir):
raise FileNotFoundError(
f'safe_load is set to True. No data files with safe extensions {SUPPORTED_EXTENSIONS} '
+ f'found for dataset {dataset_name}. ')
raise InvalidFileExtensionError(
dataset_name, SUPPORTED_EXTENSIONS)
# Set dataset_name to the downloaded location.
dataset_name = local_dataset_dir

Expand All @@ -774,9 +774,9 @@ def build_from_hf(
if not all(
Path(f).suffix in SUPPORTED_EXTENSIONS
for f in dataset_files):
raise ValueError(
f'Dataset at local path {dataset_name} contains invalid file types. '
+ f'Allowed file types are: {SUPPORTED_EXTENSIONS}')
raise InvalidFileExtensionError(dataset_name,
SUPPORTED_EXTENSIONS)

dataset = hf_datasets.load_dataset(dataset_name,
split=split,
**hf_kwargs)
Expand Down Expand Up @@ -853,9 +853,8 @@ def alpaca_preprocessing_function(inp: Dict) -> Dict[str, str]:
prompt, response = inp['text'].split('### Response:')
prompt += '### Response:'
except Exception as e:
raise ValueError(
f"Unable to extract prompt/response from 'text'={inp['text']}"
) from e
raise UnableToProcessPromptResponseError(inp) from e

return {'prompt': prompt, 'response': response}


Expand All @@ -871,8 +870,7 @@ def dolly_preprocessing_function(inp: Dict) -> Dict[str, str]:
prompt = PROMPT_FORMAT.format(instruction=instruction)
response = inp['output']
except Exception as e:
raise ValueError(
f'Unable to extract prompt/response from {inp=}') from e
raise UnableToProcessPromptResponseError(inp) from e
return {'prompt': prompt, 'response': response}


Expand All @@ -898,6 +896,5 @@ def muennighoff_tokenize_function(inp: Dict) -> Dict[str, str]:
response.startswith(transitions)):
response = ' ' + response
except Exception as e:
raise ValueError(
f'Unable to process prompt/response from {inp=}') from e
raise UnableToProcessPromptResponseError(inp) from e
return {'prompt': prompt, 'response': response}
Loading

0 comments on commit 0ef7cd6

Please sign in to comment.