Skip to content

Commit

Permalink
Raise DatasetTooSmall exception if canonical nodes is less than num s…
Browse files Browse the repository at this point in the history
…amples (#1518)

Co-authored-by: Saaketh Narayan <[email protected]>
Co-authored-by: Daniel King <[email protected]>
  • Loading branch information
3 people authored Sep 12, 2024
1 parent 6d93260 commit 5465db4
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 5 deletions.
4 changes: 3 additions & 1 deletion llmfoundry/command_utils/data_prep/convert_text_to_mds.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,7 +478,9 @@ def convert_text_to_mds(
index_path = os.path.join(local_output_folder, 'index.json')
with open(index_path, 'r') as index_file:
if not json.load(index_file)['shards']:
raise DatasetTooSmallError()
raise DatasetTooSmallError(
reason='No shards were created when converting text to MDS.',
)

# Write a done file with the args and object names
write_done_file(local_output_folder, args_str, object_names)
Expand Down
20 changes: 19 additions & 1 deletion llmfoundry/data/finetuning/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]:
ALLOWED_RESPONSE_KEYS,
ChatTemplateError,
ConsecutiveRepeatedChatRolesError,
DatasetTooSmallError,
IncorrectMessageKeyQuantityError,
InvalidContentTypeError,
InvalidExampleTypeError,
Expand Down Expand Up @@ -1033,7 +1034,24 @@ def build_from_streaming(
*args: Any,
**kwargs: Any,
) -> StreamingFinetuningDataset:
return self.streaming_dataset_class(*args, **kwargs)
dataset = self.streaming_dataset_class(*args, **kwargs)
num_canonical_nodes = dataset.num_canonical_nodes
num_samples = dataset.num_samples
if num_canonical_nodes is None:
num_physical_nodes = dist.get_world_size(
) // dist.get_local_world_size()
if num_samples < num_physical_nodes:
raise DatasetTooSmallError(
f'{num_samples=} is less than {dist.get_world_size() // dist.get_local_world_size()}, the number of physical nodes. ',
)

if num_canonical_nodes is not None and num_samples < num_canonical_nodes:
raise DatasetTooSmallError(
f'{num_samples=} is less than {num_canonical_nodes=}. ' +
'Please check your index.json file and ensure that your dataset has been written out correctly.'
+ 'If this was intended, reduce num_canonical_nodes.',
)
return dataset


dataset_constructor = DatasetConstructor()
Expand Down
6 changes: 3 additions & 3 deletions llmfoundry/utils/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,9 +376,9 @@ def __init__(self, dataset_name: str, split: str) -> None:
class DatasetTooSmallError(UserError):
"""Error thrown when the dataset is too small to be processed."""

def __init__(self) -> None:
message = f'Your dataset is too small and produced no complete samples during preprocessing. Please provide more data.'
super().__init__(message)
def __init__(self, reason: str) -> None:
message = f'Your dataset is too small and produced no complete samples or too few samples. Please provide more data. {reason}'
super().__init__(message, reason=reason)


class RunTimeoutError(InternalError):
Expand Down
44 changes: 44 additions & 0 deletions tests/data/test_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0
from contextlib import nullcontext
from typing import Optional
from unittest import mock

import pytest

from llmfoundry.data.finetuning.tasks import dataset_constructor
from llmfoundry.utils.exceptions import DatasetTooSmallError


@pytest.mark.parametrize('num_canonical_nodes', [None, 8, 2])
def test_finetuning_streaming_dataset_too_small(
num_canonical_nodes: Optional[int],
):
num_samples = 2

class MockDataset:

def __init__(self):
self.num_canonical_nodes = num_canonical_nodes
self.num_samples = num_samples

class MockDist:

def get_world_size(self):
return 32

def get_local_world_size(self):
return 8

result_context = nullcontext(
) if num_canonical_nodes == 2 else pytest.raises(DatasetTooSmallError)
with result_context:
with mock.patch(
'llmfoundry.data.finetuning.tasks.dist',
new=MockDist(),
):
with mock.patch(
'llmfoundry.data.finetuning.tasks.DatasetConstructor.streaming_dataset_class',
new=MockDataset,
):
dataset_constructor.build_from_streaming()

0 comments on commit 5465db4

Please sign in to comment.