diff --git a/llmfoundry/command_utils/__init__.py b/llmfoundry/command_utils/__init__.py index 756e611a88..617f17a642 100644 --- a/llmfoundry/command_utils/__init__.py +++ b/llmfoundry/command_utils/__init__.py @@ -23,6 +23,10 @@ convert_text_to_mds, convert_text_to_mds_from_args, ) +from llmfoundry.command_utils.data_prep.split_eval_data_from_train_data import ( + split_eval_data_from_train_data_from_args, + split_examples, +) from llmfoundry.command_utils.eval import ( eval_from_yaml, evaluate, @@ -54,4 +58,6 @@ 'convert_text_to_mds_from_args', 'convert_delta_to_json_from_args', 'fetch_DT', + 'split_eval_data_from_train_data_from_args', + 'split_examples', ] diff --git a/llmfoundry/command_utils/data_prep/split_eval_data_from_train_data.py b/llmfoundry/command_utils/data_prep/split_eval_data_from_train_data.py new file mode 100644 index 0000000000..10a8537ed6 --- /dev/null +++ b/llmfoundry/command_utils/data_prep/split_eval_data_from_train_data.py @@ -0,0 +1,159 @@ +# Copyright 2024 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +import logging +import os +import re +import tempfile +from typing import Optional + +import composer.utils as utils +import numpy as np + +log = logging.getLogger(__name__) + +REMOTE_OBJECT_STORE_FILE_REGEX = re.compile( + r'^((s3|oci|gs):\/\/|dbfs:\/Volumes\/)[/a-zA-Z0-9 ()_\-.]+$', +) + + +def is_remote_object_store_file(data_path_folder: str) -> bool: + """Check if the provided data path is a remote object store file. + + Args: + data_path_folder (str): Path to the training dataset folder + + Returns: + bool: True if the data path is a remote object store file + """ + return REMOTE_OBJECT_STORE_FILE_REGEX.match(data_path_folder) is not None + + +def maybe_download_data_as_jsonl( + data_path_folder: str, + data_path_split: str, +) -> str: + """Prepares dataset as a local JSONL file. + + Downloads from remote object store if needed. + This function is intended to be invoked by DBX Finetuning. + Thus, it assumes the provided data is: + 1. A JSONL stored as a remote object store file (e.g. S3, OCI, GCS) + 2. A Delta table converted to JSONL at 'tmp-t/{data_path_split}-00000-of-00001.jsonl` + using the 'llmfoundry.scripts.convert_delta_to_json.py' script. + + Args: + data_path_folder (str): Path to the training dataset folder + data_path_split (str): Data split + + Returns: + str: Path to the training dataset + """ + TEMP_DIR = tempfile.mkdtemp() + + if is_remote_object_store_file(data_path_folder): + log.info( + f'Downloading dataset from remote object store: {data_path_folder}{data_path_split}.jsonl', + ) + remote_path = f'{data_path_folder}/{data_path_split}.jsonl' + data_path = os.path.join(TEMP_DIR, f'{data_path_split}.jsonl') + utils.get_file(remote_path, data_path, overwrite=True) + else: + log.info( + f'Dataset is converted from Delta table. Using local file {data_path_folder}', + ) + data_path = os.path.join( + data_path_folder, + f'{data_path_split}-00000-of-00001.jsonl', + ) + + if not os.path.exists(data_path): + raise FileNotFoundError( + f'Expected dataset file at {data_path} for splitting, but it does not exist.', + ) + + return data_path + + +def split_examples( + data_path: str, + output_path: str, + eval_split_ratio: float, + max_eval_samples: Optional[int] = None, + seed: Optional[int] = None, +) -> None: + """Splits the dataset into training and evaluation sets. + + Args: + data_path (str): Path to the training dataset (local jsonl file) + output_path (str): Directory to save the split dataset + eval_split_ratio (float): Ratio of the dataset to use for evaluation. The remainder will be used for training + max_eval_samples (int): Maximum number of samples to include in the eval set. If None, all eval_split_ratio * train_dataset_size samples will be used + seed (int): Random seed for splitting the dataset + """ + os.makedirs(output_path, exist_ok=True) + + # first pass: count total number of lines and determine sample size + total_lines = 0 + with open(data_path, 'r') as infile: + for _ in infile: + total_lines += 1 + sample_size = int(eval_split_ratio * total_lines) + if max_eval_samples is not None: + sample_size = min(sample_size, max_eval_samples) + + # Use a new RNG instance with the provided seed + rng = np.random.default_rng(seed) + random_numbers = rng.random(total_lines) + + # TODO: Consider using reservoir sampling for large datasets + # Jimmy doesn't think we need to do this right now, since we will + # migrate all of this splitting logic to workflows later anyways, so + # we can do it then + sample_indices = set(np.argsort(random_numbers)[:sample_size]) + + # second pass: sample indices + with open(data_path, 'r') as infile, open( + os.path.join(output_path, 'train.jsonl'), + 'w', + ) as train_outfile, open( + os.path.join(output_path, 'eval.jsonl'), + 'w', + ) as eval_outfile: + for idx, line in enumerate(infile): + if idx in sample_indices: + eval_outfile.write(line) + else: + train_outfile.write(line) + + log.info( + f'Split {data_path} into train set of size {total_lines - sample_size} and eval set of size {sample_size}.', + ) + + +def split_eval_data_from_train_data_from_args( + data_path_folder: str, + data_path_split: str, + output_path: str, + eval_split_ratio: float, + max_eval_samples: Optional[int] = None, + seed: Optional[int] = None, +) -> None: + """A wrapper for split_examples that parses arguments. + + Args: + data_path_folder (str): Path to the training dataset folder + data_path_split (str): Data split + output_path (str): Directory to save the split dataset + eval_split_ratio (float): Ratio of the dataset to use for evaluation. The remainder will be used for training + max_eval_samples (int): Maximum number of samples to include in the eval set. If None, all eval_split_ratio * train_dataset_size samples will be used + seed (int): Random seed for splitting the dataset + """ + data_path = maybe_download_data_as_jsonl(data_path_folder, data_path_split) + split_examples( + data_path, + output_path, + eval_split_ratio, + max_eval_samples, + seed, + ) diff --git a/llmfoundry/data/finetuning/tasks.py b/llmfoundry/data/finetuning/tasks.py index b83aee3aa6..68bd1f5a23 100644 --- a/llmfoundry/data/finetuning/tasks.py +++ b/llmfoundry/data/finetuning/tasks.py @@ -713,6 +713,73 @@ def state_dict(self, num_samples: int, ) +def maybe_safe_download_hf_data( + dataset_name: str, + hf_kwargs: Optional[dict[str, Any]] = None, +) -> str: + """Download a HuggingFace dataset locally if it does not already exist. + + Args: + dataset_name (str): The name of the HuggingFace dataset to use. Can be a remote http(s) + directory or object store bucket containing the file {split}.jsonl. + hf_kwargs (dict, optional): Additional kwargs to pass to `datasets.load_dataset`. + + Returns: + str: The local path to the dataset. + """ + if hf_kwargs is None: + hf_kwargs = {} + + if not os.path.isdir(dataset_name): + # dataset_name is not a local dir path, download if needed. + local_dataset_dir = os.path.join( + tempfile.mkdtemp(), + dataset_name, + ) + + log.debug( + f'Downloading dataset {dataset_name} to {local_dataset_dir}.', + ) + + if _is_empty_or_nonexistent(dirpath=local_dataset_dir): + # Safely load the dataset from HF Hub with restricted file types. + hf_hub.snapshot_download( + dataset_name, + repo_type='dataset', + allow_patterns=['*' + ext for ext in SUPPORTED_EXTENSIONS], + token=hf_kwargs.get('token', None), + revision=hf_kwargs.get('revision', None), + local_dir_use_symlinks=False, + local_dir=local_dataset_dir, + ) + if _is_empty_or_nonexistent(dirpath=local_dataset_dir): + log.error('Failed to safely load the dataset from HF Hub.') + raise InvalidFileExtensionError( + dataset_name, + SUPPORTED_EXTENSIONS, + ) + # Set dataset_name to the downloaded location. + dataset_name = local_dataset_dir + + # Ensure dataset_name is a local directory path (using abspath to avoid confusion). + dataset_name = os.path.abspath(dataset_name) + + # Check that the directory contains only allowed file types. + dataset_files = [f for _, _, files in os.walk(dataset_name) for f in files] + if not all( + Path(f).suffix in SUPPORTED_EXTENSIONS + + HUGGINGFACE_FOLDER_EXTENSIONS or f == '.gitignore' + for f in dataset_files + ): + log.error(f'Invalid file extension found in dataset during safe load.') + raise InvalidFileExtensionError( + dataset_name, + SUPPORTED_EXTENSIONS, + ) + + return dataset_name + + class DatasetConstructor: def __init__(self): @@ -910,54 +977,10 @@ def build_from_hf( filtered_dataset = None try: if safe_load: - if not os.path.isdir(dataset_name): - # dataset_name is not a local dir path, download if needed. - local_dataset_dir = os.path.join( - tempfile.mkdtemp(), - dataset_name, - ) - - log.debug( - f'Downloading dataset {dataset_name} to {local_dataset_dir}.', - ) - - if _is_empty_or_nonexistent(dirpath=local_dataset_dir): - # Safely load a dataset from HF Hub with restricted file types. - hf_hub.snapshot_download( - dataset_name, - repo_type='dataset', - allow_patterns=[ - '*' + ext for ext in SUPPORTED_EXTENSIONS - ], - token=hf_kwargs.get('token', None), - revision=hf_kwargs.get('revision', None), - local_dir_use_symlinks=False, - local_dir=local_dataset_dir, - ) - if _is_empty_or_nonexistent(dirpath=local_dataset_dir): - raise InvalidFileExtensionError( - dataset_name, - SUPPORTED_EXTENSIONS, - ) - # Set dataset_name to the downloaded location. - dataset_name = local_dataset_dir - - # dataset_name is a local dir path. Use the abspath to prevent confusion. - dataset_name = os.path.abspath(dataset_name) - - # Ensure that the local dir contains only allowed file types. - dataset_files = [ - f for _, _, files in os.walk(dataset_name) for f in files - ] - if not all( - Path(f).suffix in SUPPORTED_EXTENSIONS + - HUGGINGFACE_FOLDER_EXTENSIONS or f == '.gitignore' - for f in dataset_files - ): - raise InvalidFileExtensionError( - dataset_name, - SUPPORTED_EXTENSIONS, - ) + dataset_name = maybe_safe_download_hf_data( + dataset_name, + hf_kwargs, + ) dataset = hf_datasets.load_dataset( dataset_name, diff --git a/scripts/data_prep/split_eval_data_from_train_data.py b/scripts/data_prep/split_eval_data_from_train_data.py new file mode 100644 index 0000000000..20e248cdfd --- /dev/null +++ b/scripts/data_prep/split_eval_data_from_train_data.py @@ -0,0 +1,61 @@ +# Copyright 2024 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +from argparse import ArgumentParser + +from llmfoundry.command_utils import split_eval_data_from_train_data_from_args + +if __name__ == '__main__': + parser = ArgumentParser( + description='Split training dataset into train and eval sets', + ) + parser.add_argument( + '--data_path_folder', + required=True, + type=str, + help='Path to the training dataset folder', + ) + parser.add_argument( + '--data_path_split', + required=True, + type=str, + help='Path to the training dataset split', + ) + parser.add_argument( + '--output_path', + required=False, + type=str, + default='/tmp-split', + help='Path to save the split dataset', + ) + parser.add_argument( + '--eval_split_ratio', + required=False, + type=float, + default=0.1, + help= + 'Ratio of the dataset to use for evaluation. The remainder will be used for training', + ) + parser.add_argument( + '--max_eval_samples', + required=False, + type=int, + default=100, + help='Maximum number of samples to include in the eval set', + ) + parser.add_argument( + '--seed', + required=False, + type=int, + default=42, + help='Random seed for splitting the dataset', + ) + args = parser.parse_args() + split_eval_data_from_train_data_from_args( + data_path_folder=args.data_path_folder, + data_path_split=args.data_path_split, + output_path=args.output_path, + eval_split_ratio=args.eval_split_ratio, + max_eval_samples=args.max_eval_samples, + seed=args.seed, + ) diff --git a/tests/a_scripts/data_prep/test_split_eval_set.py b/tests/a_scripts/data_prep/test_split_eval_set.py new file mode 100644 index 0000000000..7f9a50b351 --- /dev/null +++ b/tests/a_scripts/data_prep/test_split_eval_set.py @@ -0,0 +1,234 @@ +# Copyright 2024 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +import hashlib +import json +import os +from unittest.mock import patch + +import pytest + +from llmfoundry.command_utils import ( + split_eval_data_from_train_data_from_args, + split_examples, +) +from llmfoundry.command_utils.data_prep.split_eval_data_from_train_data import ( + REMOTE_OBJECT_STORE_FILE_REGEX, + is_remote_object_store_file, +) + +# Default values +OUTPUT_DIR = 'tmp-split' +TMPT_DIR = 'tmp-t' +DATA_PATH_SPLIT = 'train' +EVAL_SPLIT_RATIO = 0.1 +DEFAULT_FILE = TMPT_DIR + '/train-00000-of-00001.jsonl' + + +@pytest.mark.parametrize( + 'test_input, expected', + [ + ('s3://bucket-name/path/to/file', True), + ('oci://bucket-name/path/to/file', True), + ('gs://bucket-name/path/to/file', True), + ('dbfs:/Volumes/path/to/file', True), + ('s3://bucket-name/path/to/file with spaces', True), + ('https://bucket-name/path/to/file', False), + ('/local/path/to/file', False), + ('s3:/bucket-name/path/to/file', False), + ('s3://bucket-name/path/to/file?', False), + ], +) +def test_remote_object_store_file_regex( + test_input: str, + expected: bool, +) -> None: + """Test the regex pattern for remote object store file paths.""" + assert bool(REMOTE_OBJECT_STORE_FILE_REGEX.match(test_input)) == expected + + +@pytest.mark.parametrize( + 'test_input, expected', + [ + ('s3://bucket-name/path/to/file', True), + ('oci://bucket-name/path/to/file', True), + ('gs://bucket-name/path/to/file', True), + ('dbfs:/Volumes/path/to/file', True), + ('s3://bucket-name/path/to/file with spaces', True), + ('https://bucket-name/path/to/file', False), + ('/local/path/to/dir', False), + ('s3:/bucket-name/path/to/file', False), + ('s3://bucket-name/path/to/file?', False), + ], +) +def test_is_remote_object_store_file(test_input: str, expected: bool) -> None: + """Test the is_remote_object_store_file function.""" + assert is_remote_object_store_file(test_input) == expected + + +def calculate_file_hash(filepath: str) -> str: + with open(filepath, 'rb') as f: + file_hash = hashlib.sha256(f.read()).hexdigest() + return file_hash + + +def count_lines(filepath: str) -> int: + with open(filepath, 'r') as f: + return sum(1 for _ in f) + + +@pytest.fixture(scope='module', autouse=True) +def setup_and_teardown_module(): + # Setup: create local testing file + os.makedirs(TMPT_DIR, exist_ok=True) + with open(DEFAULT_FILE, 'w') as f: + for i in range(1000): + f.write( + json.dumps({ + 'prompt': 'hello world ' + str(i), + 'response': 'hi you!', + }) + '\n', + ) + yield + + # Teardown: clean up output and tmp directories + os.system(f'rm -rf {OUTPUT_DIR}') + os.system(f'rm -rf {TMPT_DIR}') + + +def test_basic_split(): + """Test basic functionality on local file.""" + output_path = os.path.join(OUTPUT_DIR, 'basic-test') + split_eval_data_from_train_data_from_args( + TMPT_DIR, + DATA_PATH_SPLIT, + output_path, + EVAL_SPLIT_RATIO, + ) + assert os.path.isfile(os.path.join(output_path, 'train.jsonl')) + assert os.path.isfile(os.path.join(output_path, 'eval.jsonl')) + + +def test_basic_split_output_exists(): + """Test that split overwrites existing files in output directory.""" + output_path = os.path.join(OUTPUT_DIR, 'basic-test') + os.makedirs(output_path, exist_ok=True) + train_file = os.path.join(output_path, 'train.jsonl') + eval_file = os.path.join(output_path, 'eval.jsonl') + with open(train_file, 'w') as f: + f.write('existing file train') + with open(eval_file, 'w') as f: + f.write('existing file eval') + old_train_hash = calculate_file_hash(train_file) + old_eval_hash = calculate_file_hash(eval_file) + split_eval_data_from_train_data_from_args( + TMPT_DIR, + DATA_PATH_SPLIT, + output_path, + EVAL_SPLIT_RATIO, + ) + assert calculate_file_hash(train_file) != old_train_hash + assert calculate_file_hash(eval_file) != old_eval_hash + + +def test_max_eval_samples(): + """Test case where max_eval_samples < eval_split_ratio * total samples""" + output_path = os.path.join(OUTPUT_DIR, 'max-eval-test') + max_eval_samples = 50 + split_eval_data_from_train_data_from_args( + TMPT_DIR, + DATA_PATH_SPLIT, + output_path, + EVAL_SPLIT_RATIO, + max_eval_samples, + ) + eval_lines = count_lines(os.path.join(output_path, 'eval.jsonl')) + assert eval_lines == max_eval_samples + + +def test_eval_split_ratio(): + """Test case where max_eval_samples is not used.""" + output_path = os.path.join(OUTPUT_DIR, 'eval-split-test') + split_eval_data_from_train_data_from_args( + TMPT_DIR, + DATA_PATH_SPLIT, + output_path, + EVAL_SPLIT_RATIO, + ) + original_data_lines = count_lines(DEFAULT_FILE) + eval_lines = count_lines(os.path.join(output_path, 'eval.jsonl')) + assert abs( + eval_lines - EVAL_SPLIT_RATIO * original_data_lines, + ) < 1 # allow for rounding errors + + +def test_seed_consistency(): + """Test if the same seed generates consistent splits.""" + output_path_1 = os.path.join(OUTPUT_DIR, 'seed-test-1') + output_path_2 = os.path.join(OUTPUT_DIR, 'seed-test-2') + split_examples(DEFAULT_FILE, output_path_1, EVAL_SPLIT_RATIO, seed=12345) + split_examples(DEFAULT_FILE, output_path_2, EVAL_SPLIT_RATIO, seed=12345) + train_hash_1 = calculate_file_hash( + os.path.join(output_path_1, 'train.jsonl'), + ) + train_hash_2 = calculate_file_hash( + os.path.join(output_path_2, 'train.jsonl'), + ) + eval_hash_1 = calculate_file_hash(os.path.join(output_path_1, 'eval.jsonl')) + eval_hash_2 = calculate_file_hash(os.path.join(output_path_2, 'eval.jsonl')) + + assert train_hash_1 == train_hash_2 + assert eval_hash_1 == eval_hash_2 + + output_path_3 = os.path.join(OUTPUT_DIR, 'seed-test-3') + split_examples(DEFAULT_FILE, output_path_3, EVAL_SPLIT_RATIO, seed=54321) + train_hash_3 = calculate_file_hash( + os.path.join(output_path_3, 'train.jsonl'), + ) + eval_hash_3 = calculate_file_hash(os.path.join(output_path_3, 'eval.jsonl')) + + assert train_hash_1 != train_hash_3 + assert eval_hash_1 != eval_hash_3 + + +def _mock_get_file(remote_path: str, data_path: str, overwrite: bool): + with open(data_path, 'w') as f: + for i in range(1000): + f.write( + json.dumps({ + 'prompt': 'hello world ' + str(i), + 'response': 'hi you!', + }) + '\n', + ) + + +def test_remote_store_data_split(): + """Test splitting a dataset from a remote store.""" + output_path = os.path.join(OUTPUT_DIR, 'remote-split-test') + with patch( + 'composer.utils.get_file', + side_effect=_mock_get_file, + ) as mock_get_file: + split_eval_data_from_train_data_from_args( + 'dbfs:/Volumes/test/test/test.jsonl', + 'unique-split-name', + output_path, + EVAL_SPLIT_RATIO, + ) + mock_get_file.assert_called() + + assert os.path.isfile(os.path.join(output_path, 'train.jsonl')) + assert os.path.isfile(os.path.join(output_path, 'eval.jsonl')) + assert count_lines(os.path.join(output_path, 'train.jsonl')) > 0 + assert count_lines(os.path.join(output_path, 'eval.jsonl')) > 0 + + +def test_missing_delta_file_error(): + # expects file 'TMPT_DIR/missing-00000-of-00001.jsonl + with pytest.raises(FileNotFoundError): + split_eval_data_from_train_data_from_args( + TMPT_DIR, + 'missing', + OUTPUT_DIR, + EVAL_SPLIT_RATIO, + )