diff --git a/llmfoundry/utils/data_prep_utils.py b/llmfoundry/utils/data_prep_utils.py index 75e27b504f..a88e65ee94 100644 --- a/llmfoundry/utils/data_prep_utils.py +++ b/llmfoundry/utils/data_prep_utils.py @@ -96,15 +96,16 @@ def __init__( def __iter__(self): for object_name in self.object_names: - object_name = object_name.strip('/') - output_filename = os.path.join(self.output_folder, object_name) + # Default output_filename, used for local paths. + output_filename = object_name + + # Download objects if remote path. if self.object_store is not None: + output_filename = os.path.join(self.output_folder, + object_name.strip('/')) self.object_store.download_object(object_name=object_name, filename=output_filename, overwrite=True) - else: - # Inputs are local so we do not need to download them. - output_filename = object_name with open(output_filename) as _txt_file: txt = _txt_file.read() diff --git a/tests/test_convert_text_to_mds.py b/tests/test_convert_text_to_mds.py index 2d4878ebbb..ab8c25bc2d 100644 --- a/tests/test_convert_text_to_mds.py +++ b/tests/test_convert_text_to_mds.py @@ -188,6 +188,37 @@ def test_single_and_multi_process(merge_shard_groups: Mock, assert n_tokens == expected_n_tokens +def test_local_path(tmp_path: pathlib.Path): + # Input/output folders + input_folder = tmp_path / 'input' + output_folder = tmp_path / 'output' + + # Create input text data + os.makedirs(input_folder, exist_ok=True) + with open(input_folder / 'test.txt', 'w') as f: + f.write('test') + + # Convert text data to mds + convert_text_to_mds( + tokenizer_name='mosaicml/mpt-7b', + output_folder=str(output_folder), + input_folder=str(input_folder), + concat_tokens=1, + eos_text='', + bos_text='', + no_wrap=False, + compression='zstd', + processes=1, + args_str='Namespace()', + reprocess=False, + ) + + # Make sure all the files exist as expected. + assert os.path.exists(output_folder / '.text_to_mds_conversion_done') + assert os.path.exists(output_folder / 'index.json') + assert os.path.exists(output_folder / 'shard.00000.mds.zstd') + + def test_is_already_processed(tmp_path: pathlib.Path): tmp_path_str = str(tmp_path) args_str = 'Namespace(x = 5)'