diff --git a/llmfoundry/data/finetuning/tasks.py b/llmfoundry/data/finetuning/tasks.py
index bc712a7504..a29dee7683 100644
--- a/llmfoundry/data/finetuning/tasks.py
+++ b/llmfoundry/data/finetuning/tasks.py
@@ -159,7 +159,6 @@ def __init__(self,
                         f'local directory {local} does not contain split {split}'
                     )
 
-        # Build Dataset
         super().__init__(
             local=local,
             remote=remote,
@@ -345,51 +344,57 @@ def build_from_hf(
             with dist.local_rank_zero_download_and_wait(signal_file_path):
                 pass
 
-        dataset = hf_datasets.load_dataset(dataset_name, split=split, **kwargs)
-
-        def dataset_mapper(example: Dict):
-            if preprocessing_fn is not None:
-                example = preprocessing_fn(example)
-            return _tokenize_formatted_example(example, tokenizer)
-
-        detected_cpu_count = os.cpu_count() or 1
-        detected_cpus_with_margin = detected_cpu_count - 8
-        num_cpus_to_use = max(1, detected_cpus_with_margin)
-
-        columns_to_remove = list(dataset[0].keys())
-        tokenized_dataset = dataset.map(
-            dataset_mapper,
-            batched=False,
-            remove_columns=columns_to_remove,
-            num_proc=num_cpus_to_use,
-            desc='Tokenizing dataset',
-        )
-
-        pad_token_id = tokenizer.pad_token_id
-
-        def filter_long_or_empty_examples(example: Dict) -> bool:
-            less_than_max_seq_len = len(example['input_ids']) < max_seq_len
-            non_empty_input = len(example['input_ids']) > 0
-            non_empty_labels = len(example['labels']) > 0
-            non_padding_response = any(
-                token_id != pad_token_id for token_id in example['labels'])
-            return (less_than_max_seq_len and non_empty_input and
-                    non_empty_labels and non_padding_response)
-
-        filtered_dataset = tokenized_dataset.filter(
-            filter_long_or_empty_examples,
-            num_proc=num_cpus_to_use,
-            desc='Filtering out long prompts',
-        )
+        error: Optional[Exception] = None
+        filtered_dataset = None
+        try:
+            dataset = hf_datasets.load_dataset(dataset_name,
+                                               split=split,
+                                               **kwargs)
+
+            def dataset_mapper(example: Dict):
+                if preprocessing_fn is not None:
+                    example = preprocessing_fn(example)
+                return _tokenize_formatted_example(example, tokenizer)
+
+            detected_cpu_count = os.cpu_count() or 1
+            detected_cpus_with_margin = detected_cpu_count - 8
+            num_cpus_to_use = max(1, detected_cpus_with_margin)
+
+            columns_to_remove = list(dataset[0].keys())
+            tokenized_dataset = dataset.map(
+                dataset_mapper,
+                batched=False,
+                remove_columns=columns_to_remove,
+                num_proc=num_cpus_to_use,
+                desc='Tokenizing dataset',
+            )
 
-        examples_removed = len(tokenized_dataset) - len(filtered_dataset)
-        if examples_removed > 0:
-            warnings.warn(
-                f'Dropped {examples_removed} examples where the prompt was longer than {max_seq_len}, '
-                +
-                'the prompt or response was empty, or the response was all padding tokens.'
+            pad_token_id = tokenizer.pad_token_id
+
+            def filter_long_or_empty_examples(example: Dict) -> bool:
+                less_than_max_seq_len = len(example['input_ids']) < max_seq_len
+                non_empty_input = len(example['input_ids']) > 0
+                non_empty_labels = len(example['labels']) > 0
+                non_padding_response = any(
+                    token_id != pad_token_id for token_id in example['labels'])
+                return (less_than_max_seq_len and non_empty_input and
+                        non_empty_labels and non_padding_response)
+
+            filtered_dataset = tokenized_dataset.filter(
+                filter_long_or_empty_examples,
+                num_proc=num_cpus_to_use,
+                desc='Filtering out long prompts',
             )
+            examples_removed = len(tokenized_dataset) - len(filtered_dataset)
+            if examples_removed > 0:
+                warnings.warn(
+                    f'Dropped {examples_removed} examples where the prompt was longer than {max_seq_len}, '
+                    +
+                    'the prompt or response was empty, or the response was all padding tokens.'
+                )
+        except Exception as e:
+            error = e
 
         # Now local rank 0 indicates to the other ranks that it is done
         if dist.get_local_rank() == 0:
             log.debug('Local rank 0 finished data prep')
@@ -403,7 +408,11 @@ def filter_long_or_empty_examples(example: Dict) -> bool:
         if dist.get_local_rank() == 0:
             os.remove(signal_file_path)
 
+        if error is not None:
+            log.error('Error during data prep')
+            raise error
         log.debug('All ranks finished data prep')
+        assert filtered_dataset is not None
         return filtered_dataset
 
     def build_from_streaming(self, *args: Any,
diff --git a/tests/test_dataloader.py b/tests/test_dataloader.py
index 4e1fd6f1f8..c35d29f74d 100644
--- a/tests/test_dataloader.py
+++ b/tests/test_dataloader.py
@@ -308,18 +308,23 @@ def test_finetuning_dataloader(decoder_only_format: bool,
 
 
 @pytest.mark.world_size(2)
+@pytest.mark.gpu
 @pytest.mark.parametrize('dataset_size', [4, 8])
 @pytest.mark.parametrize('device_batch_size', [2, 4])
 @pytest.mark.parametrize('drop_last', [True, False])
+@pytest.mark.parametrize('invalid_dataset', [True, False])
 def test_finetuning_dataloader_small_data(dataset_size: int,
                                           device_batch_size: int,
-                                          drop_last: bool):
+                                          drop_last: bool,
+                                          invalid_dataset: bool):
     tokenizer_name = 'gpt2'
     max_seq_len = 2048
     tiny_dataset_folder_path = os.path.join(os.getcwd(), 'test-ift-data-small')
     tiny_dataset_path = os.path.join(tiny_dataset_folder_path, 'train.jsonl')
     if dist.get_global_rank() == 0:
-        make_tiny_ft_dataset(path=tiny_dataset_path, size=dataset_size)
+        make_tiny_ft_dataset(path=tiny_dataset_path,
+                             size=dataset_size,
+                             add_bad_data_error=invalid_dataset)
 
     cfg = {
         'name': 'finetuning',
@@ -353,6 +358,11 @@ def test_finetuning_dataloader_small_data(dataset_size: int,
     error_context = contextlib.nullcontext()
     if (dist.get_world_size() * device_batch_size > dataset_size) and drop_last:
         error_context = pytest.raises(ValueError, match='Your dataset')
+    if invalid_dataset:
+        error_context = pytest.raises(
+            TypeError,
+            match='Unable to tokenize example because "prompt" was not a string'
+        )
 
     with error_context:
         _ = build_finetuning_dataloader(cfg, tokenizer, device_batch_size)