From c308d1026a152c2990a64868d53fc33b8360a200 Mon Sep 17 00:00:00 2001
From: Irene Dea
Date: Fri, 15 Sep 2023 14:41:09 -0700
Subject: [PATCH] Add script for MDS conversion of bucket of text files (#570)

---
 llmfoundry/utils/data_prep_utils.py      | 111 ++++
 scripts/data_prep/convert_text_to_mds.py | 440 +++++++++++++++++++++++
 tests/test_convert_text_to_mds.py        | 209 +++++++++++
 3 files changed, 760 insertions(+)
 create mode 100644 llmfoundry/utils/data_prep_utils.py
 create mode 100644 scripts/data_prep/convert_text_to_mds.py
 create mode 100644 tests/test_convert_text_to_mds.py

diff --git a/llmfoundry/utils/data_prep_utils.py b/llmfoundry/utils/data_prep_utils.py
new file mode 100644
index 0000000000..75e27b504f
--- /dev/null
+++ b/llmfoundry/utils/data_prep_utils.py
@@ -0,0 +1,111 @@
+# Copyright 2022 MosaicML LLM Foundry authors
+# SPDX-License-Identifier: Apache-2.0
+
+import json
+import os
+from glob import glob
+from typing import List, Optional
+
+from composer.utils import ObjectStore
+
+
+def with_id(basename: str, shard_id: int) -> str:
+    """Get a new basename with the given shard_id.
+
+    From https://github.com/mosaicml/streaming/blob/main/examples/multiprocess_dataset_conversion.ipynb.
+
+    Args:
+        basename (str): Old basename of file.
+        shard_id (int): New shard ID.
+
+    Returns:
+        str: New basename of file.
+    """
+    parts = basename.split('.')
+    parts[1] = f'{shard_id:05}'
+    return '.'.join(parts)
+
+
+def merge_shard_groups(root: str) -> None:
+    """Merge ephemeral sub-datasets created in parallel into one dataset.
+
+    From https://github.com/mosaicml/streaming/blob/main/examples/multiprocess_dataset
+    _conversion.ipynb.
+
+    Args:
+        root (str): Root directory.
+    """
+    pattern = os.path.join(root, '*')
+    subdirs = sorted(glob(pattern))
+    shard_id = 0
+    infos = []
+    for subdir in subdirs:
+        index_filename = os.path.join(subdir, 'index.json')
+        with open(index_filename) as index_file:
+            obj = json.load(index_file)
+        for info in obj['shards']:
+            old_basename = info['raw_data']['basename']
+            new_basename = with_id(old_basename, shard_id)
+            info['raw_data']['basename'] = new_basename
+
+            if info['zip_data'] is not None:
+                old_basename = info['zip_data']['basename']
+                new_basename = with_id(old_basename, shard_id)
+                info['zip_data']['basename'] = new_basename
+
+            old_filename = os.path.join(subdir, old_basename)
+            new_filename = os.path.join(root, new_basename)
+            os.rename(old_filename, new_filename)
+
+            shard_id += 1
+            infos.append(info)
+
+        os.remove(index_filename)
+        os.rmdir(subdir)
+
+    index_filename = os.path.join(root, 'index.json')
+    obj = {
+        'version': 2,
+        'shards': infos,
+    }
+    text = json.dumps(obj, sort_keys=True)
+    with open(index_filename, 'w') as out:
+        out.write(text)
+
+
+class DownloadingIterable:
+
+    def __init__(
+        self,
+        object_names: List[str],
+        output_folder: str,
+        object_store: Optional[ObjectStore],
+    ):
+        """Iterable that downloads files from an object store before yielding.
+
+        If object_store is None, input_folder_prefix is treated as a local path.
+
+        Args:
+            object_names (List[str]): Names of objects to download
+            output_folder (str): Local folder to write downloaded files to
+            object_store (Optional[ObjectStore]): Object store to download from
+        """
+        self.object_names = object_names
+        self.object_store = object_store
+        self.output_folder = output_folder
+
+    def __iter__(self):
+        for object_name in self.object_names:
+            object_name = object_name.strip('/')
+            output_filename = os.path.join(self.output_folder, object_name)
+            if self.object_store is not None:
+                self.object_store.download_object(object_name=object_name,
+                                                  filename=output_filename,
+                                                  overwrite=True)
+            else:
+                # Inputs are local so we do not need to download them.
+                output_filename = object_name
+
+            with open(output_filename) as _txt_file:
+                txt = _txt_file.read()
+            yield {'text': txt}
diff --git a/scripts/data_prep/convert_text_to_mds.py b/scripts/data_prep/convert_text_to_mds.py
new file mode 100644
index 0000000000..5e37da639a
--- /dev/null
+++ b/scripts/data_prep/convert_text_to_mds.py
@@ -0,0 +1,440 @@
+# Copyright 2022 MosaicML LLM Foundry authors
+# SPDX-License-Identifier: Apache-2.0
+
+import logging
+import math
+import os
+import tempfile
+from argparse import ArgumentParser, Namespace
+from concurrent.futures import ProcessPoolExecutor
+from glob import glob
+from typing import Iterable, List, Tuple, cast
+
+from composer.utils import (ObjectStore, maybe_create_object_store_from_uri,
+                            parse_uri)
+from streaming import MDSWriter
+from tqdm import tqdm
+from transformers import AutoTokenizer
+
+from llmfoundry.data import ConcatTokensDataset
+from llmfoundry.utils.data_prep_utils import (DownloadingIterable,
+                                              merge_shard_groups)
+
+log = logging.getLogger(__name__)
+DONE_FILENAME = '.text_to_mds_conversion_done'
+
+
+def parse_args() -> Namespace:
+    """Parse commandline arguments."""
+    parser = ArgumentParser(
+        description=
+        'Convert text files into MDS format, optionally concatenating and tokenizing',
+    )
+    parser.add_argument(
+        '--output_folder',
+        type=str,
+        required=True,
+        help='The folder to write output to',
+    )
+    parser.add_argument(
+        '--input_folder',
+        type=str,
+        required=True,
+        help='The folder with text files to convert to mds',
+    )
+    parser.add_argument(
+        '--compression',
+        type=str,
+        default='zstd',
+        help='The compression algorithm to use for MDS writing',
+    )
+
+    parser.add_argument(
+        '--concat_tokens',
+        type=int,
+        help='Convert text to tokens and concatenate up to this many tokens',
+    )
+
+    parser.add_argument(
+        '--tokenizer',
+        type=str,
+        help='The name of the tokenizer to use',
+    )
+    parser.add_argument(
+        '--bos_text',
+        type=str,
+        required=False,
+        default=None,
+        help=
+        'The text to prepend to each example to separate concatenated examples',
+    )
+    parser.add_argument(
+        '--eos_text',
+        type=str,
+        required=False,
+        default=None,
+        help=
+        'The text to append to each example to separate concatenated examples',
+    )
+    parser.add_argument(
+        '--no_wrap',
+        default=False,
+        action='store_true',
+        help=
+        'Whether to let text examples wrap across multiple training examples',
+    )
+    parser.add_argument(
+        '--processes',
+        type=int,
+        required=False,
+        default=1,
+        help=
+        'The number of processes to use to download and convert the dataset',
+    )
+    parser.add_argument(
+        '--reprocess',
+        type=bool,
+        required=False,
+        default=False,
+        help='If true, reprocess the input_folder to mds format. Otherwise, ' +
+        'only reprocess upon changes to the input folder or dataset creation parameters.',
+    )
+
+    parsed = parser.parse_args()
+
+    # Make sure we have needed concat options
+    if (parsed.concat_tokens is not None and
+            isinstance(parsed.concat_tokens, int) and parsed.tokenizer is None):
+        parser.error(
+            'When setting --concat_tokens, you must specify a --tokenizer')
+
+    # now that we have validated them, change BOS/EOS to strings
+    if parsed.bos_text is None:
+        parsed.bos_text = ''
+    if parsed.eos_text is None:
+        parsed.eos_text = ''
+    return parsed
+
+
+def get_object_names(input_folder: str) -> List[str]:
+    """Get object names from a local or remote folder.
+
+    Args:
+        input_folder (str): local or remote folder path.
+    """
+    object_store = maybe_create_object_store_from_uri(input_folder)
+    if object_store is not None:
+        _, _, folder_prefix = parse_uri(input_folder)
+        names = [
+            name for name in object_store.list_objects(folder_prefix)
+            if name.endswith('.txt')
+        ]
+    else:
+        # input_folder is a local folder
+        names = [
+            text_file for dirpath, _, _ in os.walk(input_folder)
+            for text_file in glob(os.path.join(dirpath, '*.txt'))
+        ]
+    # return names, sizes
+    log.info(f'Found {len(names)} text files at {input_folder}')
+
+    return names
+
+
+def get_task_args(
+    object_names: List[str],
+    output_root: str,
+    input_folder: str,
+    n_groups: int,
+    tokenizer_name: str,
+    concat_tokens: int,
+    eos_text: str,
+    bos_text: str,
+    no_wrap: bool,
+    compression: str,
+) -> Iterable:
+    """Get download_and_convert arguments split across n_groups.
+
+    Each group handles a portion of object_names.
+
+    Args:
+        object_names (List[str]): Names of objects to process
+        output_root (str): Folder to write MDS shards to
+        input_folder (str): Folder of text files to process
+        n_groups (int): Number of groups to split the object names into
+        tokenizer_name (str): Name of tokenizer to use
+        concat_tokens (int): Concatenate up to this many tokens
+        eos_text (str): Text to append to each example to separate concatenated samples
+        bos_text (str): Text to prepend to each example to separate concatenated samples
+        no_wrap (bool): Whether to let text examples wrap across multiple training examples
+        compression (str): The compression algorithm to use for MDS writing
+    """
+    num_objects = len(object_names)
+    objs_per_group = math.ceil(num_objects / n_groups)
+    for group, i in enumerate(range(0, num_objects, objs_per_group)):
+        output_subdir = os.path.join(output_root, str(group))
+        yield (
+            object_names[i:min(i + objs_per_group, num_objects)],
+            output_subdir,
+            input_folder,
+            tokenizer_name,
+            concat_tokens,
+            eos_text,
+            bos_text,
+            no_wrap,
+            compression,
+        )
+
+
+def download_and_convert_starargs(args: Tuple):
+    """Helper function to call download_and_convert with star args.
+
+    This helps us use download_and_convert with multiprocessing.
+    """
+    return download_and_convert(*args)
+
+
+def download_and_convert(
+    file_names: List[str],
+    output_folder: str,
+    input_folder: str,
+    tokenizer_name: str,
+    concat_tokens: int,
+    eos_text: str,
+    bos_text: str,
+    no_wrap: bool,
+    compression: str,
+):
+    """Downloads and converts text files to MDS format.
+
+    Args:
+        file_names (List[str]): Files to process
+        output_folder (str): Folder to write MDS shards to
+        input_folder (str): Folder of text files to process
+        tokenizer_name (str): Name of tokenizer to use
+        concat_tokens (int): Concatenate up to this many tokens
+        eos_text (str): Text to append to each example to separate concatenated samples
+        bos_text (str): Text to prepend to each example to separate concatenated samples
+        no_wrap (bool): Whether to let text examples wrap across multiple training examples
+        compression (str): The compression algorithm to use for MDS writing
+    """
+    object_store = maybe_create_object_store_from_uri(input_folder)
+
+    # Download file_names
+    with tempfile.TemporaryDirectory() as tmp_dir:
+        downloading_iter = DownloadingIterable(object_names=file_names,
+                                               output_folder=tmp_dir,
+                                               object_store=object_store)
+        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
+        tokenizer.model_max_length = 5000000000  # Hack to prevent warnings from HuggingFace
+
+        # Use the ConcatTokensDataset from LLM-foundry to concatenate sequences of tokens up
+        # to the maximum sequence length
+        dataset = ConcatTokensDataset(
+            hf_dataset=downloading_iter,
+            max_length=concat_tokens,
+            tokenizer=tokenizer,
+            eos_text=eos_text,
+            bos_text=bos_text,
+            no_wrap=no_wrap,
+        )
+
+        columns = {'tokens': 'bytes'}
+
+        log.info('Converting to MDS format...')
+        with MDSWriter(out=output_folder,
+                       columns=columns,
+                       compression=compression) as out:
+            for sample in tqdm(dataset):
+                out.write(sample)
+
+
+def is_remote_path(path: str) -> bool:
+    """Checks whether a path is a remote path.
+
+    Args:
+        path (str): path to check
+    """
+    backend, bucket, _ = parse_uri(path)
+    return backend != '' and bucket != ''
+
+
+def is_already_processed(output_root: str, args_str: str,
+                         object_names: List[str]) -> bool:
+    """Determines whether a group of text files has already been processed.
+
+    Checks the done file at output root to determine this.
+
+    Args:
+        output_root (str): Output folder where a done file may exist
+        args_str (str): String representation of the arguments
+        object_names (List[str]): Names of objects to convert to MDS format
+    """
+    # Retrieve the done file contents
+    output_object_store = maybe_create_object_store_from_uri(output_root)
+    if output_object_store is not None:
+        # Download and read the done file from the remote object store
+        _, _, output_folder_prefix = parse_uri(output_root)
+        try:
+            with tempfile.TemporaryDirectory() as tmp_dir:
+                done_file = os.path.join(tmp_dir, DONE_FILENAME)
+                output_object_store.download_object(
+                    os.path.join(output_folder_prefix, DONE_FILENAME),
+                    done_file)
+                with open(done_file) as df:
+                    done_file_contents = df.read().splitlines()
+        except FileNotFoundError:
+            return False
+    else:
+        # Read the local done file
+        done_file = os.path.join(output_root, DONE_FILENAME)
+        if not os.path.isfile(done_file):
+            return False
+        with open(done_file) as df:
+            done_file_contents = df.read().splitlines()
+    # Compare the arguments
+    prev_args_str = done_file_contents[0]
+    if prev_args_str != args_str:
+        return False
+
+    # Compare file names
+    prev_names = done_file_contents[1:]
+    if len(prev_names) != len(object_names):
+        return False
+    for idx, prev_name in enumerate(prev_names):
+        if object_names[idx] != prev_name:
+            return False
+    return True
+
+
+def write_done_file(folder: str, args_str: str, object_names: List[str]):
+    """Write a file to signify completion.
+
+    The done file includes the arguments used for processing and
+    a list of objects that were processed.
+
+    Args:
+        folder (str): Folder to write the done file to
+        args_str (str): String representation of arguments
+        object_names (List[str]): List of objects to convert to MDS format
+    """
+    with open(os.path.join(folder, DONE_FILENAME), 'w') as done_file:
+        done_file.write('\n'.join([args_str] + object_names) + '\n')
+
+
+def convert_text_to_mds(
+    tokenizer_name: str,
+    output_folder: str,
+    input_folder: str,
+    concat_tokens: int,
+    eos_text: str,
+    bos_text: str,
+    no_wrap: bool,
+    compression: str,
+    processes: int,
+    args_str: str,
+    reprocess: bool,
+):
+    """Convert a folder of text files to MDS format.
+
+    Args:
+        tokenizer_name (str): Name of tokenizer to use
+        output_folder (str): Folder to write MDS shards to
+        input_folder (str): Folder of text files to process
+        concat_tokens (int): Concatenate up to this many tokens
+        eos_text (str): Text to append to each example to separate concatenated samples
+        bos_text (str): Text to prepend to each example to separate concatenated samples
+        no_wrap (bool): Whether to let text examples wrap across multiple training examples
+        compression (str): The compression algorithm to use for MDS writing
+        processes (int): The number of processes to use.
+        args_str (str): String representation of the arguments
+        reprocess (bool): Whether to always reprocess the given folder of text files
+    """
+    is_remote_output = is_remote_path(output_folder)
+
+    object_names = get_object_names(input_folder)
+    if len(object_names) == 0:
+        raise ValueError(f'No text files were found at {input_folder}.')
+
+    # Check if the text files in the bucket have already been processed.
+    if not reprocess and is_already_processed(output_folder, args_str,
+                                              object_names):
+        log.info(
+            f'Input folder {input_folder} is already processed at {output_folder} and '
+            +
+            'reprocess is set to False. Set reprocess to True if you would like to force reprocessing.'
+ ) + return + + # Use a temporary local directory if the output is remote and there are more than 1 processes + local_output_folder = tempfile.TemporaryDirectory( + ).name if is_remote_output else output_folder + + if processes > 1: + # Download and convert the text files in parallel + args = get_task_args(object_names, local_output_folder, input_folder, + processes, tokenizer_name, concat_tokens, eos_text, + bos_text, no_wrap, compression) + with ProcessPoolExecutor(max_workers=processes) as executor: + list(executor.map(download_and_convert_starargs, args)) + + # Merge the mds shards from each of the processes into a single folder + merge_shard_groups(local_output_folder) + else: + download_and_convert(object_names, local_output_folder, input_folder, + tokenizer_name, concat_tokens, eos_text, bos_text, + no_wrap, compression) + + # Write a done file with the args and object names + write_done_file(local_output_folder, args_str, object_names) + + if is_remote_output: + # Upload the local output to the remote location + output_object_store = cast( + ObjectStore, maybe_create_object_store_from_uri(output_folder)) + _, _, output_folder_prefix = parse_uri(output_folder) + files_to_upload = os.listdir(local_output_folder) + + for file in files_to_upload: + assert not os.path.isdir(file) + remote_path = os.path.join(output_folder_prefix, file) + output_object_store.upload_object( + remote_path, os.path.join(local_output_folder, file)) + + +def _args_str(original_args: Namespace) -> str: + """Create a string from the args to determine whether to reprocess. + + Args: + original_args (Namespace): Arguments to main function. + """ + # Take the arguments that influence the final result. + # reprocess and max_mds_writer_workers are not taken. + args = Namespace( + tokenizer_name=original_args.tokenizer, + output_folder=original_args.output_folder, + input_folder=original_args.input_folder, + concat_tokens=original_args.concat_tokens, + eos_text=original_args.eos_text, + bos_text=original_args.bos_text, + no_wrap=original_args.no_wrap, + compression=original_args.compression, + processes=original_args.processes, + ) + + return str(args) + + +if __name__ == '__main__': + args = parse_args() + convert_text_to_mds(tokenizer_name=args.tokenizer, + output_folder=args.output_folder, + input_folder=args.input_folder, + concat_tokens=args.concat_tokens, + eos_text=args.eos_text, + bos_text=args.bos_text, + no_wrap=args.no_wrap, + compression=args.compression, + processes=args.processes, + reprocess=args.reprocess, + args_str=_args_str(args)) diff --git a/tests/test_convert_text_to_mds.py b/tests/test_convert_text_to_mds.py new file mode 100644 index 0000000000..2d4878ebbb --- /dev/null +++ b/tests/test_convert_text_to_mds.py @@ -0,0 +1,209 @@ +# Copyright 2022 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +import os +import sys + +import pytest + +# Add repo root to path so we can import scripts and test it +repo_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) +sys.path.append(repo_dir) +import pathlib +from concurrent.futures import ProcessPoolExecutor +from glob import glob +from typing import Callable, Iterable, List +from unittest.mock import Mock, patch + +import numpy as np +from streaming import StreamingDataset +from transformers import AutoTokenizer + +from scripts.data_prep.convert_text_to_mds import (DONE_FILENAME, + convert_text_to_mds, + download_and_convert, + is_already_processed, + merge_shard_groups, + write_done_file) + + +class 
+
+    def __init__(self, remote_folder: str, n_text_files: int,
+                 text_content: str):
+        os.makedirs(remote_folder, exist_ok=True)
+        for i in range(n_text_files):
+            with open(os.path.join(remote_folder, f'test{i}.txt'), 'w') as f:
+                f.write(text_content)
+
+        self.remote_folder = remote_folder
+        self.n_text_files = n_text_files
+
+    def download_object(self,
+                        object_name: str,
+                        filename: str,
+                        overwrite: bool = False):
+        dirname = os.path.dirname(filename)
+        if dirname:
+            os.makedirs(dirname, exist_ok=True)
+        with open(
+                os.path.join(self.remote_folder, os.path.basename(object_name)),
+                'rb') as remote_file, open(filename, 'wb') as local_file:
+            local_file.write(remote_file.read())
+
+    def list_objects(self, prefix: str) -> List[str]:
+        return glob(os.path.join(self.remote_folder, '*.txt'))
+
+    def upload_object(self, object_name: str, filename: str):
+        with open(
+                os.path.join(self.remote_folder, os.path.basename(object_name)),
+                'wb') as remote_file, open(filename, 'rb') as local_file:
+            remote_file.write(local_file.read())
+
+
+def _call_convert_text_to_mds(processes: int, tokenizer_name: str,
+                              concat_tokens: int) -> None:
+    convert_text_to_mds(
+        tokenizer_name=tokenizer_name,
+        output_folder=f's3://fake-test-output-path',
+        input_folder=f's3://fake-test-input-path',
+        concat_tokens=concat_tokens,
+        eos_text='',
+        bos_text='',
+        no_wrap=False,
+        compression='zstd',
+        processes=processes,
+        args_str='Namespace()',
+        reprocess=False,
+    )
+
+
+# Mock ProcessPoolExecutor.map with no multiprocessing
+def _mock_map(func: Callable, args: Iterable) -> Iterable:
+    for arg in args:
+        yield func(arg)
+
+
+def _assert_files_exist(prefix: str, files: List[str]):
+    for file in files:
+        assert os.path.exists(os.path.join(prefix, file))
+
+
+@pytest.mark.parametrize('processes', [1, 2, 3])
+@patch.object(ProcessPoolExecutor, 'map', new=Mock(wraps=_mock_map))
+@patch(
+    'scripts.data_prep.convert_text_to_mds.maybe_create_object_store_from_uri')
+@patch('scripts.data_prep.convert_text_to_mds.parse_uri')
+@patch('scripts.data_prep.convert_text_to_mds.download_and_convert',
+       wraps=download_and_convert)
+@patch('scripts.data_prep.convert_text_to_mds.merge_shard_groups',
+       wraps=merge_shard_groups)
+def test_single_and_multi_process(merge_shard_groups: Mock,
+                                  download_and_convert: Mock, parse_uri: Mock,
+                                  maybe_create_object_store_from_uri: Mock,
+                                  tmp_path: pathlib.Path, processes: int):
+    remote_folder = os.path.join(tmp_path, 'remote')
+    text_content = 'HELLO WORLD ' * 500
+    tokenizer_name = 'mosaicml/mpt-7b'
+    n_text_files = processes * 3
+    concat_tokens = 2048
+
+    mock_object_store = Mock(
+        wraps=MockObjectStore(remote_folder, n_text_files, text_content))
+    maybe_create_object_store_from_uri.return_value = mock_object_store
+    parse_uri.return_value = ('s3', 'fake-test-bucket', str(remote_folder))
+
+    _call_convert_text_to_mds(processes=processes,
+                              tokenizer_name=tokenizer_name,
+                              concat_tokens=concat_tokens)
+
+    # Check call counts
+    assert download_and_convert.call_count == processes  # called once per process
+    assert mock_object_store.download_object.call_count == n_text_files + 1  # text files + done file
+    assert mock_object_store.upload_object.call_count == processes + 2  # shard per process + done file + index.json
+
+    if processes > 1:
+        merge_shard_groups.assert_called_once()
+
+    total_object_names = 0
+    for call_args in download_and_convert.call_args_list:
+        object_names = call_args[0][0]
+        total_object_names += len(object_names)
+
+    assert total_object_names == n_text_files  # We should have processed all the text files
+
+    # Check that correct output files exist
+    shards = [f'shard.0000{i}.mds.zstd' for i in range(processes)]
+    _assert_files_exist(prefix=remote_folder,
+                        files=['index.json', DONE_FILENAME] + shards)
+
+    _call_convert_text_to_mds(processes=processes,
+                              tokenizer_name=tokenizer_name,
+                              concat_tokens=concat_tokens)
+
+    # Check call counts
+    assert download_and_convert.call_count == processes  # No changes because we shouldn't reprocess
+    assert mock_object_store.download_object.call_count == n_text_files + 2  # One more done file is downloaded
+    assert mock_object_store.upload_object.call_count == processes + 2  # No changes
+
+    # Create an extra text file and call again.
+    n_text_files += 1
+    object_store = MockObjectStore(remote_folder, n_text_files, text_content)
+    mock_object_store = Mock(wraps=object_store)
+    maybe_create_object_store_from_uri.return_value = mock_object_store
+
+    _call_convert_text_to_mds(processes=processes,
+                              tokenizer_name=tokenizer_name,
+                              concat_tokens=concat_tokens)
+
+    # Check call counts
+    assert download_and_convert.call_count == processes * 2  # called once per process
+    assert mock_object_store.download_object.call_count == n_text_files + 1  # text files + done file
+    assert mock_object_store.upload_object.call_count == processes + 2  # shard per process + done file + index.json
+
+    # Compute the expected number of tokens
+    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
+    tokens_per_file = len(tokenizer(text_content)['input_ids'])
+    files_per_process = [n_text_files // processes
+                        ] * processes  # Distribute the files equally
+    files_per_process[
+        0] += n_text_files % processes  # Give one of the processes the remainder
+    # expected number of tokens accounts for last tokens dropped by ConcatTokensDataset
+    expected_n_tokens = sum([
+        ((n_files * tokens_per_file) // concat_tokens) * concat_tokens
+        for n_files in files_per_process
+    ])
+
+    dataset = StreamingDataset(local=remote_folder, num_canonical_nodes=1)
+    n_tokens = 0
+    for i in range(dataset.num_samples):
+        sample = dataset[i]
+        tokens = np.frombuffer(sample['tokens'], dtype=int)
+        if i == 0:  # For the first sample, check that the decoded sample matches the text_content
+            decoded = tokenizer.decode(tokens)
+            assert decoded == text_content[:len(decoded)]
+        n_tokens += len(tokens)
+
+    # Check that the number of tokens found while iterating through the dataset is as expected.
+    assert n_tokens == expected_n_tokens
+
+
+def test_is_already_processed(tmp_path: pathlib.Path):
+    tmp_path_str = str(tmp_path)
+    args_str = 'Namespace(x = 5)'
+    object_names = ['test0.txt', 'test1.txt']
+
+    assert not is_already_processed(tmp_path_str, args_str,
+                                    object_names)  # Done file doesn't exist
+
+    write_done_file(tmp_path_str, args_str, object_names)
+    assert is_already_processed(tmp_path_str, args_str,
+                                object_names)  # Args and names match
+
+    write_done_file(tmp_path_str, args_str, object_names + ['test2.txt'])
+    assert not is_already_processed(tmp_path_str, args_str,
+                                    object_names)  # Object names differ
+
+    write_done_file(tmp_path_str, 'Namespace()', object_names)
+    assert not is_already_processed(tmp_path_str, args_str,
+                                    object_names)  # Argument strings differ
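
Usage sketch (not part of the patch): the snippet below shows one way the new entry point might be called programmatically, mirroring the `if __name__ == '__main__'` block in scripts/data_prep/convert_text_to_mds.py. The bucket paths, tokenizer name, and args_str value are illustrative placeholders, not values taken from this PR.

    # Illustrative example only; all paths and names below are placeholders.
    from scripts.data_prep.convert_text_to_mds import convert_text_to_mds

    convert_text_to_mds(
        tokenizer_name='mosaicml/mpt-7b',            # any HuggingFace tokenizer name
        output_folder='s3://my-bucket/mds-out',      # remote or local output folder
        input_folder='s3://my-bucket/raw-text',      # folder of .txt files, remote or local
        concat_tokens=2048,                          # pack token sequences up to this length
        eos_text='',                                 # text appended between concatenated examples
        bos_text='',                                 # text prepended to each example
        no_wrap=False,
        compression='zstd',
        processes=4,                                 # parallel download/convert workers
        args_str='my-run-args',                      # recorded in the done file for reprocess checks
        reprocess=False,
    )

Equivalently, the script's argparse interface exposes the same parameters as --tokenizer, --output_folder, --input_folder, --concat_tokens, --eos_text, --bos_text, --no_wrap, --compression, --processes, and --reprocess.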